//! Batch processing system.
//!
//! Jobs are registered in a shared ordered queue, and the first caller
//! to acquire the processor becomes the leader, applying one batch on
//! behalf of all queued callers. Jobs can be cancelled early, before
//! they are taken into a batch.
use std::collections::BTreeMap;
use std::pin::pin;
use std::sync::Mutex;

use scopeguard::ScopeGuard;
use tokio::sync::oneshot::error::TryRecvError;

use crate::ext::LockExt;

pub trait QueueProcessing: Send + 'static {
    type Req: Send + 'static;
    type Res: Send;

    /// Get the desired batch size.
    fn batch_size(&self, queue_size: usize) -> usize;

    /// This applies a full batch of events.
    /// Must respond with a full batch of replies.
    ///
    /// If this apply can fail, it's expected that errors are forwarded to each Self::Res.
    ///
    /// Batching does not need to happen atomically.
    fn apply(&mut self, req: Vec<Self::Req>) -> impl Future<Output = Vec<Self::Res>> + Send;
}
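
// Illustrative sketch, not part of the original module: a minimal
// `QueueProcessing` implementation. `UppercaseProcessor` and its cap of
// 16 are hypothetical, chosen only to demonstrate the contract above:
// `apply` must return exactly one response per request, in request order.
#[cfg(test)]
struct UppercaseProcessor;

#[cfg(test)]
impl QueueProcessing for UppercaseProcessor {
    type Req = String;
    type Res = String;

    /// Drain the whole queue, capped at an arbitrary maximum batch size.
    fn batch_size(&self, queue_size: usize) -> usize {
        queue_size.min(16)
    }

    /// One response per request, preserving order.
    async fn apply(&mut self, reqs: Vec<String>) -> Vec<String> {
        reqs.into_iter().map(|s| s.to_uppercase()).collect()
    }
}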

pub struct BatchQueue<P: QueueProcessing> {
    processor: tokio::sync::Mutex<P>,
    inner: Mutex<BatchQueueInner<P>>,
}

struct BatchJob<P: QueueProcessing> {
    req: P::Req,
    res: tokio::sync::oneshot::Sender<P::Res>,
}

impl<P: QueueProcessing> BatchQueue<P> {
    pub fn new(p: P) -> Self {
        Self {
            processor: tokio::sync::Mutex::new(p),
            inner: Mutex::new(BatchQueueInner {
                version: 0,
                queue: BTreeMap::new(),
            }),
        }
    }

    /// Perform a single request-response exchange; this may be batched internally.
    ///
    /// This function is not cancel safe.
    pub async fn call<R>(
        &self,
        req: P::Req,
        cancelled: impl Future<Output = R>,
    ) -> Result<P::Res, R> {
        let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);

        let mut cancelled = pin!(cancelled);
        let resp = loop {
            // try to become the leader, or wait for success or cancellation.
            let mut processor = tokio::select! {
                // try to become the leader.
                p = self.processor.lock() => p,
                // wait for success.
                resp = &mut rx => break resp.ok(),
                // wait for cancellation.
                cancel = cancelled.as_mut() => {
                    let mut inner = self.inner.lock_propagate_poison();
                    if inner.queue.remove(&id).is_some() {
                        tracing::warn!("batched task cancelled before completion");
                    }
                    return Err(cancel);
                },
            };

            tracing::debug!(id, "batch: became leader");
            let (reqs, resps) = self.inner.lock_propagate_poison().get_batch(&processor);

            // snitch in case the task gets cancelled.
            let cancel_safety = scopeguard::guard((), |()| {
                if !std::thread::panicking() {
                    tracing::error!(
                        id,
                        "batch: leader cancelled, despite not being cancellation safe"
                    );
                }
            });

            // apply a batch.
            // if this is cancelled, jobs will not be completed and will panic.
            let values = processor.apply(reqs).await;

            // good: we didn't get cancelled.
            ScopeGuard::into_inner(cancel_safety);

            if values.len() != resps.len() {
                tracing::error!(
                    "batch: invalid response size, expected={}, got={}",
                    resps.len(),
                    values.len()
                );
            }

            // send response values.
            for (tx, value) in std::iter::zip(resps, values) {
                if tx.send(value).is_err() {
                    // receiver hung up but that's fine.
                }
            }

            match rx.try_recv() {
                Ok(resp) => break Some(resp),
                Err(TryRecvError::Closed) => break None,
                // edge case - there was a race condition where
                // we became the leader but were not in the batch.
                //
                // Example:
                // thread 1: register job id=1
                // thread 2: register job id=2
                // thread 2: processor.lock().await
                // thread 1: processor.lock().await
                // thread 2: becomes leader, batch_size=1, jobs=[1].
                Err(TryRecvError::Empty) => {}
            }
        };

        tracing::debug!(id, "batch: job completed");

        Ok(resp.expect("no response found. batch processor should not panic"))
    }
}
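
// Hypothetical usage sketch (not part of the original module): pairing
// `call` with a real cancellation signal, e.g. tokio-util's
// `CancellationToken::cancelled()`. If the token fires first, the job is
// removed from the queue (if still pending) and `Err(())` is returned:
//
//     let res = queue.call(req, token.cancelled()).await;
//     match res {
//         Ok(resp) => { /* handle the batched response */ }
//         Err(()) => tracing::info!("request cancelled before completion"),
//     }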

struct BatchQueueInner<P: QueueProcessing> {
    version: u64,
    queue: BTreeMap<u64, BatchJob<P>>,
}

impl<P: QueueProcessing> BatchQueueInner<P> {
    fn register_job(&mut self, req: P::Req) -> (u64, tokio::sync::oneshot::Receiver<P::Res>) {
        let (tx, rx) = tokio::sync::oneshot::channel();

        let id = self.version;

        // Overflow concern:
        // This is a u64, and we might enqueue 2^16 tasks per second.
        // That gives us 2^48 seconds (about 9 million years) before overflow.
        // Even if this does overflow, it will not break, but some
        // jobs with a higher version might never get prioritised.
        self.version += 1;

        self.queue.insert(id, BatchJob { req, res: tx });

        tracing::debug!(id, "batch: registered job in the queue");

        (id, rx)
    }

    fn get_batch(&mut self, p: &P) -> (Vec<P::Req>, Vec<tokio::sync::oneshot::Sender<P::Res>>) {
        let batch_size = p.batch_size(self.queue.len());
        let mut reqs = Vec::with_capacity(batch_size);
        let mut resps = Vec::with_capacity(batch_size);
        let mut ids = Vec::with_capacity(batch_size);

        while reqs.len() < batch_size {
            let Some((id, job)) = self.queue.pop_first() else {
                break;
            };
            reqs.push(job.req);
            resps.push(job.res);
            ids.push(id);
        }

        tracing::debug!(ids=?ids, "batch: acquired jobs");

        (reqs, resps)
    }
}
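
// A minimal end-to-end sketch under assumed setup (not original code):
// a single caller becomes the leader, applies its own one-element batch,
// and receives its response. `std::future::pending` never resolves, so
// the cancellation branch is never taken. Assumes tokio's "rt" and
// "macros" features are available for tests.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn single_caller_becomes_leader() {
        let queue = BatchQueue::new(UppercaseProcessor);

        let res = queue
            .call("hello".to_string(), std::future::pending::<()>())
            .await;

        assert_eq!(res, Ok("HELLO".to_string()));
    }
}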
        
