//! Batch processing system.
//!
//! Jobs are registered in a shared ordered queue. One caller becomes the
//! leader, applies the whole batch, and distributes the results. Queued
//! jobs can be cancelled early.
use std::collections::BTreeMap;
use std::pin::pin;
use std::sync::Mutex;

use scopeguard::ScopeGuard;
use tokio::sync::oneshot;
use tokio::sync::oneshot::error::TryRecvError;

use crate::ext::LockExt;

type ProcResult<P> = Result<<P as QueueProcessing>::Res, <P as QueueProcessing>::Err>;

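/// A processor that consumes batches of requests and produces one response
/// per request.
///
/// A minimal sketch of an implementor, using a hypothetical `Echo` processor
/// (illustrative only, not part of this module):
///
/// ```ignore
/// struct Echo;
///
/// impl QueueProcessing for Echo {
///     type Req = String;
///     type Res = String;
///     type Err = std::convert::Infallible;
///
///     // Process at most 32 queued jobs per batch.
///     fn batch_size(&self, queue_size: usize) -> usize {
///         queue_size.min(32)
///     }
///
///     // Must return exactly one response per request, in order.
///     async fn apply(&mut self, reqs: Vec<String>) -> Result<Vec<String>, Self::Err> {
///         Ok(reqs)
///     }
/// }
/// ```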
pub trait QueueProcessing: Send + 'static {
    type Req: Send + 'static;
    type Res: Send;
    type Err: Send + Clone;

    /// Get the desired batch size, given the current queue length.
    fn batch_size(&self, queue_size: usize) -> usize;

    /// Apply a full batch of requests.
    /// Must respond with a full batch of replies, one per request and in order.
    ///
    /// If individual requests can fail, those errors are expected to be
    /// forwarded inside each `Self::Res`; returning `Err` fails the whole
    /// batch with one shared error.
    ///
    /// Batching does not need to happen atomically.
    fn apply(
        &mut self,
        req: Vec<Self::Req>,
    ) -> impl Future<Output = Result<Vec<Self::Res>, Self::Err>> + Send;
}

#[derive(thiserror::Error)]
pub enum BatchQueueError<E: Clone, C> {
    #[error(transparent)]
    Result(E),
    #[error(transparent)]
    Cancelled(C),
}

pub struct BatchQueue<P: QueueProcessing> {
    processor: tokio::sync::Mutex<P>,
    inner: Mutex<BatchQueueInner<P>>,
}

struct BatchJob<P: QueueProcessing> {
    req: P::Req,
    res: tokio::sync::oneshot::Sender<Result<P::Res, P::Err>>,
}

impl<P: QueueProcessing> BatchQueue<P> {
    pub fn new(p: P) -> Self {
        Self {
            processor: tokio::sync::Mutex::new(p),
            inner: Mutex::new(BatchQueueInner {
                version: 0,
                queue: BTreeMap::new(),
            }),
        }
    }

    /// Perform a single request-response exchange; it may be batched
    /// internally with other concurrent calls.
    ///
    /// This function is not cancel safe.
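    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the hypothetical `Echo` processor from the
    /// [`QueueProcessing`] docs and a 5-second give-up timeout:
    ///
    /// ```ignore
    /// let queue = BatchQueue::new(Echo);
    /// let timeout = tokio::time::sleep(std::time::Duration::from_secs(5));
    /// match queue.call("ping".to_owned(), timeout).await {
    ///     Ok(res) => println!("response: {res}"),
    ///     Err(BatchQueueError::Result(err)) => { /* the processor failed */ }
    ///     Err(BatchQueueError::Cancelled(())) => { /* timed out while queued */ }
    /// }
    /// ```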
    pub async fn call<R>(
        &self,
        req: P::Req,
        cancelled: impl Future<Output = R>,
    ) -> Result<P::Res, BatchQueueError<P::Err, R>> {
        let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);

        let mut cancelled = pin!(cancelled);
        let resp: Option<Result<P::Res, P::Err>> = loop {
            // try to become the leader, or wait for success.
            let mut processor = tokio::select! {
                // try to become leader.
                p = self.processor.lock() => p,
                // wait for success.
                resp = &mut rx => break resp.ok(),
                // wait for cancellation.
                cancel = cancelled.as_mut() => {
                    let mut inner = self.inner.lock_propagate_poison();
                    if inner.queue.remove(&id).is_some() {
                        tracing::warn!("batched task cancelled before completion");
                    }
                    return Err(BatchQueueError::Cancelled(cancel));
                },
            };

            tracing::debug!(id, "batch: became leader");
            let (reqs, resps) = self.inner.lock_propagate_poison().get_batch(&processor);

            // snitch in case the task gets cancelled.
            let cancel_safety = scopeguard::guard((), |()| {
                if !std::thread::panicking() {
                    tracing::error!(
                        id,
                        "batch: leader cancelled, despite not being cancellation safe"
                    );
                }
            });

            // apply a batch.
            // if this is cancelled, queued jobs will never complete and their waiters will panic.
            let values = processor.apply(reqs).await;

            // good: we didn't get cancelled.
            ScopeGuard::into_inner(cancel_safety);

            match values {
                Ok(values) => {
                    if values.len() != resps.len() {
                        tracing::error!(
                            "batch: invalid response size, expected={}, got={}",
                            resps.len(),
                            values.len()
                        );
                    }

                    // send response values.
                    for (tx, value) in std::iter::zip(resps, values) {
                        if tx.send(Ok(value)).is_err() {
                            // receiver hung up but that's fine.
                        }
                    }
                }

                Err(err) => {
                    for tx in resps {
                        if tx.send(Err(err.clone())).is_err() {
                            // receiver hung up but that's fine.
                        }
                    }
                }
            }

            match rx.try_recv() {
                Ok(resp) => break Some(resp),
                Err(TryRecvError::Closed) => break None,
                // edge case - there was a race condition where
                // we became the leader but were not in the batch.
                //
                // Example:
                // thread 1: register job id=1
                // thread 2: register job id=2
                // thread 2: processor.lock().await
                // thread 1: processor.lock().await
                // thread 2: becomes leader, batch_size=1, jobs=[1].
                Err(TryRecvError::Empty) => {}
            }
        };

        tracing::debug!(id, "batch: job completed");

        resp.expect("no response found. batch processor should not panic")
            .map_err(BatchQueueError::Result)
    }
}

struct BatchQueueInner<P: QueueProcessing> {
    version: u64,
    queue: BTreeMap<u64, BatchJob<P>>,
}

impl<P: QueueProcessing> BatchQueueInner<P> {
    fn register_job(&mut self, req: P::Req) -> (u64, oneshot::Receiver<ProcResult<P>>) {
        let (tx, rx) = oneshot::channel();

        let id = self.version;

        // Overflow concern:
        // This is a u64, and we might enqueue 2^16 tasks per second.
        // That gives us 2^48 seconds (about 9 million years) before overflow.
        // Even if it does overflow, nothing breaks, but jobs with
        // higher versions might never get prioritised.
        self.version += 1;

        self.queue.insert(id, BatchJob { req, res: tx });

        tracing::debug!(id, "batch: registered job in the queue");

        (id, rx)
    }

    fn get_batch(&mut self, p: &P) -> (Vec<P::Req>, Vec<oneshot::Sender<ProcResult<P>>>) {
        let batch_size = p.batch_size(self.queue.len());
        let mut reqs = Vec::with_capacity(batch_size);
        let mut resps = Vec::with_capacity(batch_size);
        let mut ids = Vec::with_capacity(batch_size);

        while reqs.len() < batch_size {
            let Some((id, job)) = self.queue.pop_first() else {
                break;
            };
            reqs.push(job.req);
            resps.push(job.res);
            ids.push(id);
        }

        tracing::debug!(ids=?ids, "batch: acquired jobs");

        (reqs, resps)
    }
}
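
// A smoke-test sketch of the queue end to end: `Doubler` is a hypothetical
// processor (not part of this module), and the test assumes tokio's `macros`
// and `rt` features are enabled.
#[cfg(test)]
mod tests {
    use super::*;

    /// Doubles every request in the batch.
    struct Doubler;

    impl QueueProcessing for Doubler {
        type Req = u64;
        type Res = u64;
        type Err = std::convert::Infallible;

        fn batch_size(&self, queue_size: usize) -> usize {
            queue_size.max(1)
        }

        async fn apply(&mut self, reqs: Vec<u64>) -> Result<Vec<u64>, Self::Err> {
            Ok(reqs.into_iter().map(|x| x * 2).collect())
        }
    }

    #[tokio::test]
    async fn single_call_is_processed() {
        let queue = BatchQueue::new(Doubler);
        // A future that never resolves: this call can never be cancelled.
        let res = queue.call(21, std::future::pending::<()>()).await;
        assert!(matches!(res, Ok(42)));
    }
}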