LCOV - code coverage report
Current view: top level - proxy/src/scram - threadpool.rs (source / functions) Coverage Total Hit
Test: 2aa98e37cd3250b9a68c97ef6050b16fe702ab33.info Lines: 93.7 % 223 209
Test Date: 2024-08-29 11:33:10 Functions: 93.3 % 15 14

            Line data    Source code
       1              : //! Custom threadpool implementation for password hashing.
       2              : //!
       3              : //! Requirements:
       4              : //! 1. Fairness per endpoint.
       5              : //! 2. Yield support for high iteration counts.
       6              : 
       7              : use std::sync::{
       8              :     atomic::{AtomicU64, Ordering},
       9              :     Arc,
      10              : };
      11              : 
      12              : use crossbeam_deque::{Injector, Stealer, Worker};
      13              : use itertools::Itertools;
      14              : use parking_lot::{Condvar, Mutex};
      15              : use rand::Rng;
      16              : use rand::{rngs::SmallRng, SeedableRng};
      17              : use tokio::sync::oneshot;
      18              : 
      19              : use crate::{
      20              :     intern::EndpointIdInt,
      21              :     metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
      22              :     scram::countmin::CountMinSketch,
      23              : };
      24              : 
      25              : use super::pbkdf2::Pbkdf2;
      26              : 
      27              : pub struct ThreadPool {
      28              :     queue: Injector<JobSpec>,
      29              :     stealers: Vec<Stealer<JobSpec>>,
      30              :     parkers: Vec<(Condvar, Mutex<ThreadState>)>,
      31              :     /// bitpacked representation.
      32              :     /// lower 8 bits = number of sleeping threads
      33              :     /// next 8 bits = number of idle threads (searching for work)
      34              :     counters: AtomicU64,
      35              : 
      36              :     pub metrics: Arc<ThreadPoolMetrics>,
      37              : }
      38              : 
      39              : #[derive(PartialEq)]
      40              : enum ThreadState {
      41              :     Parked,
      42              :     Active,
      43              : }
      44              : 
      45              : impl ThreadPool {
      46           36 :     pub fn new(n_workers: u8) -> Arc<Self> {
      47           36 :         let workers = (0..n_workers).map(|_| Worker::new_fifo()).collect_vec();
      48           36 :         let stealers = workers.iter().map(|w| w.stealer()).collect_vec();
      49           36 : 
      50           36 :         let parkers = (0..n_workers)
      51           36 :             .map(|_| (Condvar::new(), Mutex::new(ThreadState::Active)))
      52           36 :             .collect_vec();
      53           36 : 
      54           36 :         let pool = Arc::new(Self {
      55           36 :             queue: Injector::new(),
      56           36 :             stealers,
      57           36 :             parkers,
      58           36 :             // threads start searching for work
      59           36 :             counters: AtomicU64::new((n_workers as u64) << 8),
      60           36 :             metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
      61           36 :         });
      62              : 
      63           36 :         for (i, worker) in workers.into_iter().enumerate() {
      64           36 :             let pool = Arc::clone(&pool);
      65           36 :             std::thread::spawn(move || thread_rt(pool, worker, i));
      66           36 :         }
      67              : 
      68           36 :         pool
      69           36 :     }
      70              : 
      71           30 :     pub(crate) fn spawn_job(
      72           30 :         &self,
      73           30 :         endpoint: EndpointIdInt,
      74           30 :         pbkdf2: Pbkdf2,
      75           30 :     ) -> oneshot::Receiver<[u8; 32]> {
      76           30 :         let (tx, rx) = oneshot::channel();
      77           30 : 
      78           30 :         let queue_was_empty = self.queue.is_empty();
      79           30 : 
      80           30 :         self.metrics.injector_queue_depth.inc();
      81           30 :         self.queue.push(JobSpec {
      82           30 :             response: tx,
      83           30 :             pbkdf2,
      84           30 :             endpoint,
      85           30 :         });
      86           30 : 
      87           30 :         // inspired from <https://github.com/rayon-rs/rayon/blob/3e3962cb8f7b50773bcc360b48a7a674a53a2c77/rayon-core/src/sleep/mod.rs#L242>
      88           30 :         let counts = self.counters.load(Ordering::SeqCst);
      89           30 :         let num_awake_but_idle = (counts >> 8) & 0xff;
      90           30 :         let num_sleepers = counts & 0xff;
      91           30 : 
      92           30 :         // If the queue is non-empty, then we always wake up a worker
      93           30 :         // -- clearly the existing idle jobs aren't enough. Otherwise,
      94           30 :         // check to see if we have enough idle workers.
      95           30 :         if !queue_was_empty || num_awake_but_idle == 0 {
      96           30 :             let num_to_wake = Ord::min(1, num_sleepers);
      97           30 :             self.wake_any_threads(num_to_wake);
      98           30 :         }
      99              : 
     100           30 :         rx
     101           30 :     }
     102              : 
     103              :     #[cold]
     104           30 :     fn wake_any_threads(&self, mut num_to_wake: u64) {
     105           30 :         if num_to_wake > 0 {
     106           30 :             for i in 0..self.parkers.len() {
     107           30 :                 if self.wake_specific_thread(i) {
     108           30 :                     num_to_wake -= 1;
     109           30 :                     if num_to_wake == 0 {
     110           30 :                         return;
     111            0 :                     }
     112            0 :                 }
     113              :             }
     114            0 :         }
     115           30 :     }
     116              : 
     117           30 :     fn wake_specific_thread(&self, index: usize) -> bool {
     118           30 :         let (condvar, lock) = &self.parkers[index];
     119           30 : 
     120           30 :         let mut state = lock.lock();
     121           30 :         if *state == ThreadState::Parked {
     122           30 :             condvar.notify_one();
     123           30 : 
     124           30 :             // When the thread went to sleep, it will have incremented
     125           30 :             // this value. When we wake it, its our job to decrement
     126           30 :             // it. We could have the thread do it, but that would
     127           30 :             // introduce a delay between when the thread was
     128           30 :             // *notified* and when this counter was decremented. That
     129           30 :             // might mislead people with new work into thinking that
     130           30 :             // there are sleeping threads that they should try to
     131           30 :             // wake, when in fact there is nothing left for them to
     132           30 :             // do.
     133           30 :             self.counters.fetch_sub(1, Ordering::SeqCst);
     134           30 :             *state = ThreadState::Active;
     135           30 : 
     136           30 :             true
     137              :         } else {
     138            0 :             false
     139              :         }
     140           30 :     }
     141              : 
     142           59 :     fn steal(&self, rng: &mut impl Rng, skip: usize, worker: &Worker<JobSpec>) -> Option<JobSpec> {
     143           59 :         // announce thread as idle
     144           59 :         self.counters.fetch_add(256, Ordering::SeqCst);
     145              : 
     146              :         // try steal from the global queue
     147           59 :         loop {
     148           59 :             match self.queue.steal_batch_and_pop(worker) {
     149           30 :                 crossbeam_deque::Steal::Success(job) => {
     150           30 :                     self.metrics
     151           30 :                         .injector_queue_depth
     152           30 :                         .set(self.queue.len() as i64);
     153           30 :                     // no longer idle
     154           30 :                     self.counters.fetch_sub(256, Ordering::SeqCst);
     155           30 :                     return Some(job);
     156              :                 }
     157            0 :                 crossbeam_deque::Steal::Retry => continue,
     158           29 :                 crossbeam_deque::Steal::Empty => break,
     159              :             }
     160              :         }
     161              : 
     162              :         // try steal from our neighbours
     163           29 :         loop {
     164           29 :             let mut retry = false;
     165           29 :             let start = rng.gen_range(0..self.stealers.len());
     166           29 :             let job = (start..self.stealers.len())
     167           29 :                 .chain(0..start)
     168           29 :                 .filter(|i| *i != skip)
     169           29 :                 .find_map(
     170           29 :                     |victim| match self.stealers[victim].steal_batch_and_pop(worker) {
     171            0 :                         crossbeam_deque::Steal::Success(job) => Some(job),
     172            0 :                         crossbeam_deque::Steal::Empty => None,
     173              :                         crossbeam_deque::Steal::Retry => {
     174            0 :                             retry = true;
     175            0 :                             None
     176              :                         }
     177           29 :                     },
     178           29 :                 );
     179           29 :             if job.is_some() {
     180              :                 // no longer idle
     181            0 :                 self.counters.fetch_sub(256, Ordering::SeqCst);
     182            0 :                 return job;
     183           29 :             }
     184           29 :             if !retry {
     185           29 :                 return None;
     186            0 :             }
     187              :         }
     188           59 :     }
     189              : }
     190              : 
     191           36 : fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
     192           36 :     /// interval when we should steal from the global queue
     193           36 :     /// so that tail latencies are managed appropriately
     194           36 :     const STEAL_INTERVAL: usize = 61;
     195           36 : 
     196           36 :     /// How often to reset the sketch values
     197           36 :     const SKETCH_RESET_INTERVAL: usize = 1021;
     198           36 : 
     199           36 :     let mut rng = SmallRng::from_entropy();
     200           36 : 
     201           36 :     // used to determine whether we should temporarily skip tasks for fairness.
     202           36 :     // 99% of estimates will overcount by no more than 4096 samples
     203           36 :     let mut sketch = CountMinSketch::with_params(1.0 / (SKETCH_RESET_INTERVAL as f64), 0.01);
     204           36 : 
     205           36 :     let (condvar, lock) = &pool.parkers[index];
     206              : 
     207           65 :     'wait: loop {
     208           65 :         // wait for notification of work
     209           65 :         {
     210           65 :             let mut lock = lock.lock();
     211           65 : 
     212           65 :             // queue is empty
     213           65 :             pool.metrics
     214           65 :                 .worker_queue_depth
     215           65 :                 .set(ThreadPoolWorkerId(index), 0);
     216           65 : 
     217           65 :             // subtract 1 from idle count, add 1 to sleeping count.
     218           65 :             pool.counters.fetch_sub(255, Ordering::SeqCst);
     219           65 : 
     220           65 :             *lock = ThreadState::Parked;
     221           65 :             condvar.wait(&mut lock);
     222           65 :         }
     223              : 
     224           99 :         for i in 0.. {
     225           99 :             let Some(mut job) = worker
     226           99 :                 .pop()
     227           99 :                 .or_else(|| pool.steal(&mut rng, index, &worker))
     228              :             else {
     229           29 :                 continue 'wait;
     230              :             };
     231              : 
     232           70 :             pool.metrics
     233           70 :                 .worker_queue_depth
     234           70 :                 .set(ThreadPoolWorkerId(index), worker.len() as i64);
     235           70 : 
     236           70 :             // receiver is closed, cancel the task
     237           70 :             if !job.response.is_closed() {
     238           70 :                 let rate = sketch.inc_and_return(&job.endpoint, job.pbkdf2.cost());
     239           70 : 
     240           70 :                 const P: f64 = 2000.0;
     241           70 :                 // probability decreases as rate increases.
     242           70 :                 // lower probability, higher chance of being skipped
     243           70 :                 //
     244           70 :                 // estimates (rate in terms of 4096 rounds):
     245           70 :                 // rate = 0    => probability = 100%
     246           70 :                 // rate = 10   => probability = 71.3%
     247           70 :                 // rate = 50   => probability = 62.1%
     248           70 :                 // rate = 500  => probability = 52.3%
     249           70 :                 // rate = 1021 => probability = 49.8%
     250           70 :                 //
     251           70 :                 // My expectation is that the pool queue will only begin backing up at ~1000rps
     252           70 :                 // in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above
     253           70 :                 // are in requests per second.
     254           70 :                 let probability = P.ln() / (P + rate as f64).ln();
     255           70 :                 if pool.queue.len() > 32 || rng.gen_bool(probability) {
     256           65 :                     pool.metrics
     257           65 :                         .worker_task_turns_total
     258           65 :                         .inc(ThreadPoolWorkerId(index));
     259           65 : 
     260           65 :                     match job.pbkdf2.turn() {
     261           30 :                         std::task::Poll::Ready(result) => {
     262           30 :                             let _ = job.response.send(result);
     263           30 :                         }
     264            0 :                         std::task::Poll::Pending => worker.push(job),
     265              :                     }
     266            5 :                 } else {
     267            5 :                     pool.metrics
     268            5 :                         .worker_task_skips_total
     269            5 :                         .inc(ThreadPoolWorkerId(index));
     270            5 : 
     271            5 :                     // skip for now
     272            5 :                     worker.push(job);
     273            5 :                 }
     274            0 :             }
     275              : 
     276              :             // if we get stuck with a few long lived jobs in the queue
     277              :             // it's better to try and steal from the queue too for fairness
     278           35 :             if i % STEAL_INTERVAL == 0 {
     279           29 :                 let _ = pool.queue.steal_batch(&worker);
     280           29 :             }
     281              : 
     282           35 :             if i % SKETCH_RESET_INTERVAL == 0 {
     283           29 :                 sketch.reset();
     284           29 :             }
     285              :         }
     286              :     }
     287              : }
     288              : 
     289              : struct JobSpec {
     290              :     response: oneshot::Sender<[u8; 32]>,
     291              :     pbkdf2: Pbkdf2,
     292              :     endpoint: EndpointIdInt,
     293              : }
     294              : 
     295              : #[cfg(test)]
     296              : mod tests {
     297              :     use crate::EndpointId;
     298              : 
     299              :     use super::*;
     300              : 
     301              :     #[tokio::test]
     302            6 :     async fn hash_is_correct() {
     303            6 :         let pool = ThreadPool::new(1);
     304            6 : 
     305            6 :         let ep = EndpointId::from("foo");
     306            6 :         let ep = EndpointIdInt::from(ep);
     307            6 : 
     308            6 :         let salt = [0x55; 32];
     309            6 :         let actual = pool
     310            6 :             .spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
     311            6 :             .await
     312            6 :             .unwrap();
     313            6 : 
     314            6 :         let expected = [
     315            6 :             10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
     316            6 :             178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
     317            6 :         ];
     318            6 :         assert_eq!(actual, expected);
     319            6 :     }
     320              : }
        

Generated by: LCOV version 2.1-beta