LCOV - code coverage report
Current view: top level - proxy/src/scram - threadpool.rs (source / functions) Coverage Total Hit
Test: e402c46de0a007db6b48dddbde450ddbb92e6ceb.info Lines: 93.7 % 222 208
Test Date: 2024-06-25 10:31:23 Functions: 93.3 % 15 14

            Line data    Source code
       1              : //! Custom threadpool implementation for password hashing.
       2              : //!
       3              : //! Requirements:
       4              : //! 1. Fairness per endpoint.
       5              : //! 2. Yield support for high iteration counts.
       6              : 
       7              : use std::sync::{
       8              :     atomic::{AtomicU64, Ordering},
       9              :     Arc,
      10              : };
      11              : 
      12              : use crossbeam_deque::{Injector, Stealer, Worker};
      13              : use itertools::Itertools;
      14              : use parking_lot::{Condvar, Mutex};
      15              : use rand::Rng;
      16              : use rand::{rngs::SmallRng, SeedableRng};
      17              : use tokio::sync::oneshot;
      18              : 
      19              : use crate::{
      20              :     intern::EndpointIdInt,
      21              :     metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
      22              :     scram::countmin::CountMinSketch,
      23              : };
      24              : 
      25              : use super::pbkdf2::Pbkdf2;
      26              : 
      27              : pub struct ThreadPool {
      28              :     queue: Injector<JobSpec>,
      29              :     stealers: Vec<Stealer<JobSpec>>,
      30              :     parkers: Vec<(Condvar, Mutex<ThreadState>)>,
      31              :     /// bitpacked representation.
      32              :     /// lower 8 bits = number of sleeping threads
      33              :     /// next 8 bits = number of idle threads (searching for work)
      34              :     counters: AtomicU64,
      35              : 
      36              :     pub metrics: Arc<ThreadPoolMetrics>,
      37              : }
      38              : 
      39              : #[derive(PartialEq)]
      40              : enum ThreadState {
      41              :     Parked,
      42              :     Active,
      43              : }
      44              : 
      45              : impl ThreadPool {
      46           12 :     pub fn new(n_workers: u8) -> Arc<Self> {
      47           12 :         let workers = (0..n_workers).map(|_| Worker::new_fifo()).collect_vec();
      48           12 :         let stealers = workers.iter().map(|w| w.stealer()).collect_vec();
      49           12 : 
      50           12 :         let parkers = (0..n_workers)
      51           12 :             .map(|_| (Condvar::new(), Mutex::new(ThreadState::Active)))
      52           12 :             .collect_vec();
      53           12 : 
      54           12 :         let pool = Arc::new(Self {
      55           12 :             queue: Injector::new(),
      56           12 :             stealers,
      57           12 :             parkers,
      58           12 :             // threads start searching for work
      59           12 :             counters: AtomicU64::new((n_workers as u64) << 8),
      60           12 :             metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
      61           12 :         });
      62              : 
      63           12 :         for (i, worker) in workers.into_iter().enumerate() {
      64           12 :             let pool = Arc::clone(&pool);
      65           12 :             std::thread::spawn(move || thread_rt(pool, worker, i));
      66           12 :         }
      67              : 
      68           12 :         pool
      69           12 :     }
      70              : 
      71           10 :     pub fn spawn_job(
      72           10 :         &self,
      73           10 :         endpoint: EndpointIdInt,
      74           10 :         pbkdf2: Pbkdf2,
      75           10 :     ) -> oneshot::Receiver<[u8; 32]> {
      76           10 :         let (tx, rx) = oneshot::channel();
      77           10 : 
      78           10 :         let queue_was_empty = self.queue.is_empty();
      79           10 : 
      80           10 :         self.metrics.injector_queue_depth.inc();
      81           10 :         self.queue.push(JobSpec {
      82           10 :             response: tx,
      83           10 :             pbkdf2,
      84           10 :             endpoint,
      85           10 :         });
      86           10 : 
      87           10 :         // inspired from <https://github.com/rayon-rs/rayon/blob/3e3962cb8f7b50773bcc360b48a7a674a53a2c77/rayon-core/src/sleep/mod.rs#L242>
      88           10 :         let counts = self.counters.load(Ordering::SeqCst);
      89           10 :         let num_awake_but_idle = (counts >> 8) & 0xff;
      90           10 :         let num_sleepers = counts & 0xff;
      91           10 : 
      92           10 :         // If the queue is non-empty, then we always wake up a worker
      93           10 :         // -- clearly the existing idle jobs aren't enough. Otherwise,
      94           10 :         // check to see if we have enough idle workers.
      95           10 :         if !queue_was_empty || num_awake_but_idle == 0 {
      96           10 :             let num_to_wake = Ord::min(1, num_sleepers);
      97           10 :             self.wake_any_threads(num_to_wake);
      98           10 :         }
      99              : 
     100           10 :         rx
     101           10 :     }
     102              : 
     103              :     #[cold]
     104           10 :     fn wake_any_threads(&self, mut num_to_wake: u64) {
     105           10 :         if num_to_wake > 0 {
     106           10 :             for i in 0..self.parkers.len() {
     107           10 :                 if self.wake_specific_thread(i) {
     108           10 :                     num_to_wake -= 1;
     109           10 :                     if num_to_wake == 0 {
     110           10 :                         return;
     111            0 :                     }
     112            0 :                 }
     113              :             }
     114            0 :         }
     115           10 :     }
     116              : 
     117           10 :     fn wake_specific_thread(&self, index: usize) -> bool {
     118           10 :         let (condvar, lock) = &self.parkers[index];
     119           10 : 
     120           10 :         let mut state = lock.lock();
     121           10 :         if *state == ThreadState::Parked {
     122           10 :             condvar.notify_one();
     123           10 : 
     124           10 :             // When the thread went to sleep, it will have incremented
     125           10 :             // this value. When we wake it, its our job to decrement
     126           10 :             // it. We could have the thread do it, but that would
     127           10 :             // introduce a delay between when the thread was
     128           10 :             // *notified* and when this counter was decremented. That
     129           10 :             // might mislead people with new work into thinking that
     130           10 :             // there are sleeping threads that they should try to
     131           10 :             // wake, when in fact there is nothing left for them to
     132           10 :             // do.
     133           10 :             self.counters.fetch_sub(1, Ordering::SeqCst);
     134           10 :             *state = ThreadState::Active;
     135           10 : 
     136           10 :             true
     137              :         } else {
     138            0 :             false
     139              :         }
     140           10 :     }
     141              : 
     142           19 :     fn steal(&self, rng: &mut impl Rng, skip: usize, worker: &Worker<JobSpec>) -> Option<JobSpec> {
     143           19 :         // announce thread as idle
     144           19 :         self.counters.fetch_add(256, Ordering::SeqCst);
     145              : 
     146              :         // try steal from the global queue
     147           19 :         loop {
     148           19 :             match self.queue.steal_batch_and_pop(worker) {
     149           10 :                 crossbeam_deque::Steal::Success(job) => {
     150           10 :                     self.metrics
     151           10 :                         .injector_queue_depth
     152           10 :                         .set(self.queue.len() as i64);
     153           10 :                     // no longer idle
     154           10 :                     self.counters.fetch_sub(256, Ordering::SeqCst);
     155           10 :                     return Some(job);
     156              :                 }
     157            0 :                 crossbeam_deque::Steal::Retry => continue,
     158            9 :                 crossbeam_deque::Steal::Empty => break,
     159              :             }
     160              :         }
     161              : 
     162              :         // try steal from our neighbours
     163            9 :         loop {
     164            9 :             let mut retry = false;
     165            9 :             let start = rng.gen_range(0..self.stealers.len());
     166            9 :             let job = (start..self.stealers.len())
     167            9 :                 .chain(0..start)
     168            9 :                 .filter(|i| *i != skip)
     169            9 :                 .find_map(
     170            9 :                     |victim| match self.stealers[victim].steal_batch_and_pop(worker) {
     171            0 :                         crossbeam_deque::Steal::Success(job) => Some(job),
     172            0 :                         crossbeam_deque::Steal::Empty => None,
     173              :                         crossbeam_deque::Steal::Retry => {
     174            0 :                             retry = true;
     175            0 :                             None
     176              :                         }
     177            9 :                     },
     178            9 :                 );
     179            9 :             if job.is_some() {
     180              :                 // no longer idle
     181            0 :                 self.counters.fetch_sub(256, Ordering::SeqCst);
     182            0 :                 return job;
     183            9 :             }
     184            9 :             if !retry {
     185            9 :                 return None;
     186            0 :             }
     187              :         }
     188           19 :     }
     189              : }
     190              : 
     191           12 : fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
     192           12 :     /// interval when we should steal from the global queue
     193           12 :     /// so that tail latencies are managed appropriately
     194           12 :     const STEAL_INTERVAL: usize = 61;
     195           12 : 
     196           12 :     /// How often to reset the sketch values
     197           12 :     const SKETCH_RESET_INTERVAL: usize = 1021;
     198           12 : 
     199           12 :     let mut rng = SmallRng::from_entropy();
     200           12 : 
     201           12 :     // used to determine whether we should temporarily skip tasks for fairness.
     202           12 :     // 99% of estimates will overcount by no more than 4096 samples
     203           12 :     let mut sketch = CountMinSketch::with_params(1.0 / (SKETCH_RESET_INTERVAL as f64), 0.01);
     204           12 : 
     205           12 :     let (condvar, lock) = &pool.parkers[index];
     206              : 
     207           21 :     'wait: loop {
     208           21 :         // wait for notification of work
     209           21 :         {
     210           21 :             let mut lock = lock.lock();
     211           21 : 
     212           21 :             // queue is empty
     213           21 :             pool.metrics
     214           21 :                 .worker_queue_depth
     215           21 :                 .set(ThreadPoolWorkerId(index), 0);
     216           21 : 
     217           21 :             // subtract 1 from idle count, add 1 to sleeping count.
     218           21 :             pool.counters.fetch_sub(255, Ordering::SeqCst);
     219           21 : 
     220           21 :             *lock = ThreadState::Parked;
     221           21 :             condvar.wait(&mut lock);
     222           21 :         }
     223              : 
     224           32 :         for i in 0.. {
     225           32 :             let mut job = match worker
     226           32 :                 .pop()
     227           32 :                 .or_else(|| pool.steal(&mut rng, index, &worker))
     228              :             {
     229           23 :                 Some(job) => job,
     230            9 :                 None => continue 'wait,
     231              :             };
     232              : 
     233           23 :             pool.metrics
     234           23 :                 .worker_queue_depth
     235           23 :                 .set(ThreadPoolWorkerId(index), worker.len() as i64);
     236           23 : 
     237           23 :             // receiver is closed, cancel the task
     238           23 :             if !job.response.is_closed() {
     239           23 :                 let rate = sketch.inc_and_return(&job.endpoint, job.pbkdf2.cost());
     240           23 : 
     241           23 :                 const P: f64 = 2000.0;
     242           23 :                 // probability decreases as rate increases.
     243           23 :                 // lower probability, higher chance of being skipped
     244           23 :                 //
     245           23 :                 // estimates (rate in terms of 4096 rounds):
     246           23 :                 // rate = 0    => probability = 100%
     247           23 :                 // rate = 10   => probability = 71.3%
     248           23 :                 // rate = 50   => probability = 62.1%
     249           23 :                 // rate = 500  => probability = 52.3%
     250           23 :                 // rate = 1021 => probability = 49.8%
     251           23 :                 //
     252           23 :                 // My expectation is that the pool queue will only begin backing up at ~1000rps
     253           23 :                 // in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above
     254           23 :                 // are in requests per second.
     255           23 :                 let probability = P.ln() / (P + rate as f64).ln();
     256           23 :                 if pool.queue.len() > 32 || rng.gen_bool(probability) {
     257           21 :                     pool.metrics
     258           21 :                         .worker_task_turns_total
     259           21 :                         .inc(ThreadPoolWorkerId(index));
     260           21 : 
     261           21 :                     match job.pbkdf2.turn() {
     262           10 :                         std::task::Poll::Ready(result) => {
     263           10 :                             let _ = job.response.send(result);
     264           10 :                         }
     265            0 :                         std::task::Poll::Pending => worker.push(job),
     266              :                     }
     267              :                 } else {
     268            2 :                     pool.metrics
     269            2 :                         .worker_task_skips_total
     270            2 :                         .inc(ThreadPoolWorkerId(index));
     271            2 : 
     272            2 :                     // skip for now
     273            2 :                     worker.push(job)
     274              :                 }
     275            0 :             }
     276              : 
     277              :             // if we get stuck with a few long lived jobs in the queue
     278              :             // it's better to try and steal from the queue too for fairness
     279           12 :             if i % STEAL_INTERVAL == 0 {
     280            9 :                 let _ = pool.queue.steal_batch(&worker);
     281            9 :             }
     282              : 
     283           12 :             if i % SKETCH_RESET_INTERVAL == 0 {
     284            9 :                 sketch.reset();
     285            9 :             }
     286              :         }
     287              :     }
     288              : }
     289              : 
     290              : struct JobSpec {
     291              :     response: oneshot::Sender<[u8; 32]>,
     292              :     pbkdf2: Pbkdf2,
     293              :     endpoint: EndpointIdInt,
     294              : }
     295              : 
     296              : #[cfg(test)]
     297              : mod tests {
     298              :     use crate::EndpointId;
     299              : 
     300              :     use super::*;
     301              : 
     302              :     #[tokio::test]
     303            2 :     async fn hash_is_correct() {
     304            2 :         let pool = ThreadPool::new(1);
     305            2 : 
     306            2 :         let ep = EndpointId::from("foo");
     307            2 :         let ep = EndpointIdInt::from(ep);
     308            2 : 
     309            2 :         let salt = [0x55; 32];
     310            2 :         let actual = pool
     311            2 :             .spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
     312            2 :             .await
     313            2 :             .unwrap();
     314            2 : 
     315            2 :         let expected = [
     316            2 :             10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
     317            2 :             178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
     318            2 :         ];
     319            2 :         assert_eq!(actual, expected)
     320            2 :     }
     321              : }
        

Generated by: LCOV version 2.1-beta