//! Batch processing system.
//!
//! Jobs are registered in a shared queue and applied in batches,
//! with direct support for cancelling jobs early.
use std::collections::BTreeMap;
use std::pin::pin;
use std::sync::Mutex;

use scopeguard::ScopeGuard;
use tokio::sync::oneshot;
use tokio::sync::oneshot::error::TryRecvError;

use crate::ext::LockExt;

type ProcResult<P> = Result<<P as QueueProcessing>::Res, <P as QueueProcessing>::Err>;

pub trait QueueProcessing: Send + 'static {
    type Req: Send + 'static;
    type Res: Send;
    type Err: Send + Clone;

    /// Get the desired batch size.
    fn batch_size(&self, queue_size: usize) -> usize;

    /// This applies a full batch of events.
    /// Must respond with a full batch of replies.
    ///
    /// If this apply can fail, it's expected that errors are forwarded to each `Self::Res`.
    ///
    /// Batching does not need to happen atomically.
    fn apply(
        &mut self,
        req: Vec<Self::Req>,
    ) -> impl Future<Output = Result<Vec<Self::Res>, Self::Err>> + Send;
}
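
// Illustrative sketch (test-only): a minimal `QueueProcessing` implementation,
// showing the contract of `batch_size` and `apply`. The `DoublingProcessor`
// name and its behaviour are assumptions made for this example; a real
// processor would batch expensive work such as database writes.
#[cfg(test)]
mod example_processor {
    use super::QueueProcessing;

    /// Hypothetical processor that doubles each request.
    pub(crate) struct DoublingProcessor {
        /// Upper bound on how many queued jobs are taken per batch.
        pub(crate) max_batch: usize,
    }

    impl QueueProcessing for DoublingProcessor {
        type Req = u64;
        type Res = u64;
        type Err = String;

        fn batch_size(&self, queue_size: usize) -> usize {
            // Take everything currently queued, up to the configured maximum.
            queue_size.min(self.max_batch)
        }

        fn apply(
            &mut self,
            req: Vec<u64>,
        ) -> impl std::future::Future<Output = Result<Vec<u64>, String>> + Send {
            // Must return exactly one response per request, in the same order.
            std::future::ready(Ok(req.into_iter().map(|r| r * 2).collect()))
        }
    }
}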

#[derive(thiserror::Error)]
pub enum BatchQueueError<E: Clone, C> {
    #[error(transparent)]
    Result(E),
    #[error(transparent)]
    Cancelled(C),
}

pub struct BatchQueue<P: QueueProcessing> {
    processor: tokio::sync::Mutex<P>,
    inner: Mutex<BatchQueueInner<P>>,
}

struct BatchJob<P: QueueProcessing> {
    req: P::Req,
    res: tokio::sync::oneshot::Sender<Result<P::Res, P::Err>>,
}

impl<P: QueueProcessing> BatchQueue<P> {
    pub fn new(p: P) -> Self {
        Self {
            processor: tokio::sync::Mutex::new(p),
            inner: Mutex::new(BatchQueueInner {
                version: 0,
                queue: BTreeMap::new(),
            }),
        }
    }

    /// Perform a single request-response exchange; it may be batched internally.
    ///
    /// This function is not cancel safe.
    pub async fn call<R>(
        &self,
        req: P::Req,
        cancelled: impl Future<Output = R>,
    ) -> Result<P::Res, BatchQueueError<P::Err, R>> {
        let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);

        let mut cancelled = pin!(cancelled);
        let resp: Option<Result<P::Res, P::Err>> = loop {
            // Try to become the leader, or wait for success.
            let mut processor = tokio::select! {
                // try to become the leader.
                p = self.processor.lock() => p,
                // wait for success.
                resp = &mut rx => break resp.ok(),
                // wait for cancellation.
                cancel = cancelled.as_mut() => {
                    let mut inner = self.inner.lock_propagate_poison();
                    if inner.queue.remove(&id).is_some() {
                        tracing::warn!("batched task cancelled before completion");
                    }
                    return Err(BatchQueueError::Cancelled(cancel));
                },
            };

            tracing::debug!(id, "batch: became leader");
            let (reqs, resps) = self.inner.lock_propagate_poison().get_batch(&processor);

            // Snitch in case the task gets cancelled.
            let cancel_safety = scopeguard::guard((), |()| {
                if !std::thread::panicking() {
                    tracing::error!(
                        id,
                        "batch: leader cancelled, despite not being cancellation safe"
                    );
                }
            });

            // apply a batch.
            // If this is cancelled, the queued jobs will never be completed and
            // their waiters will panic.
            let values = processor.apply(reqs).await;

            // good: we didn't get cancelled.
            ScopeGuard::into_inner(cancel_safety);

            match values {
                Ok(values) => {
                    if values.len() != resps.len() {
                        tracing::error!(
                            "batch: invalid response size, expected={}, got={}",
                            resps.len(),
                            values.len()
                        );
                    }

                    // send response values.
                    for (tx, value) in std::iter::zip(resps, values) {
                        if tx.send(Ok(value)).is_err() {
                            // receiver hung up but that's fine.
                        }
                    }
                }

                Err(err) => {
                    for tx in resps {
                        if tx.send(Err(err.clone())).is_err() {
                            // receiver hung up but that's fine.
                        }
                    }
                }
            }

            match rx.try_recv() {
                Ok(resp) => break Some(resp),
                Err(TryRecvError::Closed) => break None,
                // edge case - there was a race condition where
                // we became the leader but were not in the batch.
                //
                // Example:
                // thread 1: register job id=1
                // thread 2: register job id=2
                // thread 2: processor.lock().await
                // thread 1: processor.lock().await
                // thread 2: becomes leader, batch_size=1, jobs=[1].
                Err(TryRecvError::Empty) => {}
            }
        };

        tracing::debug!(id, "batch: job completed");

        resp.expect("no response found. batch processor should not panic")
            .map_err(BatchQueueError::Result)
    }
}
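
// Illustrative usage sketch (test-only): a single caller goes through
// `BatchQueue::call`, becomes the leader of a one-element batch, and gets its
// doubled response back. It reuses the hypothetical `DoublingProcessor` from
// the sketch above and assumes tokio's `rt` and `macros` features are enabled
// for `#[tokio::test]`.
#[cfg(test)]
mod example_usage {
    use super::example_processor::DoublingProcessor;
    use super::BatchQueue;

    #[tokio::test]
    async fn single_call_is_batched_and_processed() {
        let queue = BatchQueue::new(DoublingProcessor { max_batch: 8 });

        // `std::future::pending` never resolves, so this call is never cancelled.
        let res = queue.call(21u64, std::future::pending::<()>()).await;

        match res {
            Ok(value) => assert_eq!(value, 42),
            Err(_) => panic!("batch call failed"),
        }
    }
}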

struct BatchQueueInner<P: QueueProcessing> {
    version: u64,
    queue: BTreeMap<u64, BatchJob<P>>,
}

impl<P: QueueProcessing> BatchQueueInner<P> {
    fn register_job(&mut self, req: P::Req) -> (u64, oneshot::Receiver<ProcResult<P>>) {
        let (tx, rx) = oneshot::channel();

        let id = self.version;

        // Overflow concern:
        // This is a u64, and we might enqueue 2^16 tasks per second.
        // This gives us 2^48 seconds (9 million years).
        // Even if this does overflow, it will not break, but some
        // jobs with the higher version might never get prioritised.
        self.version += 1;

        self.queue.insert(id, BatchJob { req, res: tx });

        tracing::debug!(id, "batch: registered job in the queue");

        (id, rx)
    }

    fn get_batch(&mut self, p: &P) -> (Vec<P::Req>, Vec<oneshot::Sender<ProcResult<P>>>) {
        let batch_size = p.batch_size(self.queue.len());
        let mut reqs = Vec::with_capacity(batch_size);
        let mut resps = Vec::with_capacity(batch_size);
        let mut ids = Vec::with_capacity(batch_size);

        while reqs.len() < batch_size {
            let Some((id, job)) = self.queue.pop_first() else {
                break;
            };
            reqs.push(job.req);
            resps.push(job.res);
            ids.push(id);
        }

        tracing::debug!(ids=?ids, "batch: acquired jobs");

        (reqs, resps)
    }
}