//! Batch processing system based on a shared, ordered job queue.
//!
//! Enqueuing a batch job registers it in a queue keyed by insertion
//! order, with direct support for cancelling jobs early.
use std::collections::BTreeMap;
use std::pin::pin;
use std::sync::Mutex;

use scopeguard::ScopeGuard;
use tokio::sync::oneshot::error::TryRecvError;

use crate::ext::LockExt;

pub trait QueueProcessing: Send + 'static {
    type Req: Send + 'static;
    type Res: Send;

    /// Get the desired batch size, given the current queue size.
    fn batch_size(&self, queue_size: usize) -> usize;

    /// Apply a full batch of requests.
    /// Must respond with a full batch of replies, one per request and in the same order.
    ///
    /// If this apply can fail, errors are expected to be forwarded to each `Self::Res`
    /// rather than returned from this method.
    ///
    /// Batching does not need to happen atomically.
    fn apply(&mut self, req: Vec<Self::Req>) -> impl Future<Output = Vec<Self::Res>> + Send;
}
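
// A minimal illustrative implementor (not part of the original module),
// kept behind `#[cfg(test)]`. It assumes a hypothetical `Echo` processor
// that answers each request with its own value and caps batches at 8 jobs.
#[cfg(test)]
struct Echo;

#[cfg(test)]
impl QueueProcessing for Echo {
    type Req = u32;
    type Res = u32;

    // Take everything that is queued, up to 8 jobs per batch.
    fn batch_size(&self, queue_size: usize) -> usize {
        queue_size.min(8)
    }

    // Reply to each request with its own value, preserving order.
    async fn apply(&mut self, req: Vec<u32>) -> Vec<u32> {
        req
    }
}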

pub struct BatchQueue<P: QueueProcessing> {
    processor: tokio::sync::Mutex<P>,
    inner: Mutex<BatchQueueInner<P>>,
}

struct BatchJob<P: QueueProcessing> {
    req: P::Req,
    res: tokio::sync::oneshot::Sender<P::Res>,
}

impl<P: QueueProcessing> BatchQueue<P> {
    pub fn new(p: P) -> Self {
        Self {
            processor: tokio::sync::Mutex::new(p),
            inner: Mutex::new(BatchQueueInner {
                version: 0,
                queue: BTreeMap::new(),
            }),
        }
    }

    /// Perform a single request-response exchange; it may be batched
    /// internally with other concurrent calls.
    ///
    /// This function is not cancel safe.
    pub async fn call<R>(
        &self,
        req: P::Req,
        cancelled: impl Future<Output = R>,
    ) -> Result<P::Res, R> {
        let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);

        let mut cancelled = pin!(cancelled);
        let resp = loop {
            // try to become the leader, or wait for completion/cancellation.
            let mut processor = tokio::select! {
                // try to become the leader.
                p = self.processor.lock() => p,
                // wait for success.
                resp = &mut rx => break resp.ok(),
                // wait for cancellation.
                cancel = cancelled.as_mut() => {
                    let mut inner = self.inner.lock_propagate_poison();
                    if inner.queue.remove(&id).is_some() {
                        tracing::warn!("batched task cancelled before completion");
                    }
                    return Err(cancel);
                },
            };

            tracing::debug!(id, "batch: became leader");
            let (reqs, resps) = self.inner.lock_propagate_poison().get_batch(&processor);

            // snitch in case the task gets cancelled.
            let cancel_safety = scopeguard::guard((), |()| {
                if !std::thread::panicking() {
                    tracing::error!(
                        id,
                        "batch: leader cancelled, despite not being cancellation safe"
                    );
                }
            });

            // apply a batch.
            // if this is cancelled, jobs will not be completed and will panic.
            let values = processor.apply(reqs).await;

            // good: we didn't get cancelled.
            ScopeGuard::into_inner(cancel_safety);

            if values.len() != resps.len() {
                tracing::error!(
                    "batch: invalid response size, expected={}, got={}",
                    resps.len(),
                    values.len()
                );
            }

            // send response values.
            for (tx, value) in std::iter::zip(resps, values) {
                if tx.send(value).is_err() {
                    // receiver hung up, but that's fine.
                }
            }

            match rx.try_recv() {
                Ok(resp) => break Some(resp),
                Err(TryRecvError::Closed) => break None,
                // edge case - there was a race condition where
                // we became the leader but were not in the batch.
                //
                // Example:
                // thread 1: register job id=1
                // thread 2: register job id=2
                // thread 2: processor.lock().await
                // thread 1: processor.lock().await
                // thread 2: becomes leader, batch_size=1, jobs=[1].
                Err(TryRecvError::Empty) => {}
            }
        };

        tracing::debug!(id, "batch: job completed");

        Ok(resp.expect("no response found; the batch processor should not panic"))
    }
}
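
// A usage sketch (test-only, not from the original code): a caller passes a
// request plus a cancellation future, and `call` resolves once the batch
// containing the job has been applied. `std::future::pending` stands in for
// a cancellation signal that never fires; this uses the hypothetical `Echo`
// processor sketched above.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn echo_roundtrip() {
        let queue = BatchQueue::new(Echo);

        // With no contention this caller becomes the leader, applies a
        // batch of one job, and receives its own response.
        let resp = queue.call(42u32, std::future::pending::<()>()).await;
        assert_eq!(resp, Ok(42));
    }
}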

struct BatchQueueInner<P: QueueProcessing> {
    version: u64,
    queue: BTreeMap<u64, BatchJob<P>>,
}

impl<P: QueueProcessing> BatchQueueInner<P> {
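    /// Register a job in the queue under a fresh id, returning the id and
    /// the receiver on which the response will eventually arrive.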
    fn register_job(&mut self, req: P::Req) -> (u64, tokio::sync::oneshot::Receiver<P::Res>) {
        let (tx, rx) = tokio::sync::oneshot::channel();

        let id = self.version;

        // Overflow concern:
        // This is a u64, and we might enqueue 2^16 tasks per second.
        // That gives us 2^48 seconds (about 9 million years) before overflow.
        // Even if it does overflow, nothing breaks, but jobs with
        // higher versions might never get prioritised.
        self.version += 1;

        self.queue.insert(id, BatchJob { req, res: tx });

        tracing::debug!(id, "batch: registered job in the queue");

        (id, rx)
    }

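    /// Pop up to `batch_size` jobs from the front of the queue (oldest first),
    /// returning their requests and response channels.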
    fn get_batch(&mut self, p: &P) -> (Vec<P::Req>, Vec<tokio::sync::oneshot::Sender<P::Res>>) {
        let batch_size = p.batch_size(self.queue.len());
        let mut reqs = Vec::with_capacity(batch_size);
        let mut resps = Vec::with_capacity(batch_size);
        let mut ids = Vec::with_capacity(batch_size);

        while reqs.len() < batch_size {
            let Some((id, job)) = self.queue.pop_first() else {
                break;
            };
            reqs.push(job.req);
            resps.push(job.res);
            ids.push(id);
        }

        tracing::debug!(ids=?ids, "batch: acquired jobs");

        (reqs, resps)
    }
}