LCOV - code coverage report
Current view: top level - proxy/src/scram - threadpool.rs (source / functions)
Test:         b9d67f908f91f00e353a27440ba89f642a869959.info
Test Date:    2024-11-19 21:44:13
Coverage:     Lines: 92.6 % (100 of 108 hit)    Functions: 100.0 % (13 of 13 hit)

            Line data    Source code
       1              : //! Custom threadpool implementation for password hashing.
       2              : //!
       3              : //! Requirements:
       4              : //! 1. Fairness per endpoint.
       5              : //! 2. Yield support for high iteration counts.
       6              : 
       7              : use std::cell::RefCell;
       8              : use std::future::Future;
       9              : use std::pin::Pin;
      10              : use std::sync::atomic::{AtomicUsize, Ordering};
      11              : use std::sync::{Arc, Weak};
      12              : use std::task::{Context, Poll};
      13              : 
      14              : use futures::FutureExt;
      15              : use rand::rngs::SmallRng;
      16              : use rand::{Rng, SeedableRng};
      17              : 
      18              : use super::pbkdf2::Pbkdf2;
      19              : use crate::intern::EndpointIdInt;
      20              : use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
      21              : use crate::scram::countmin::CountMinSketch;
      22              : 
      23              : pub struct ThreadPool {
      24              :     runtime: Option<tokio::runtime::Runtime>,
      25              :     pub metrics: Arc<ThreadPoolMetrics>,
      26              : }
      27              : 
      28              : /// How often to reset the sketch values
      29              : const SKETCH_RESET_INTERVAL: u64 = 1021;
      30              : 
      31              : thread_local! {
      32              :     static STATE: RefCell<Option<ThreadRt>> = const { RefCell::new(None) };
      33              : }
      34              : 
      35              : impl ThreadPool {
      36            6 :     pub fn new(n_workers: u8) -> Arc<Self> {
       37            6 :         // rayon would be nice here, but yielding in rayon does not work well as far as I can tell.
      38            6 : 
      39            6 :         if n_workers == 0 {
      40            0 :             return Arc::new(Self {
      41            0 :                 runtime: None,
      42            0 :                 metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
      43            0 :             });
      44            6 :         }
      45            6 : 
      46            6 :         Arc::new_cyclic(|pool| {
      47            6 :             let pool = pool.clone();
      48            6 :             let worker_id = AtomicUsize::new(0);
      49            6 : 
      50            6 :             let runtime = tokio::runtime::Builder::new_multi_thread()
      51            6 :                 .worker_threads(n_workers as usize)
      52            6 :                 .on_thread_start(move || {
      53            6 :                     STATE.with_borrow_mut(|state| {
      54            6 :                         *state = Some(ThreadRt {
      55            6 :                             pool: pool.clone(),
      56            6 :                             id: ThreadPoolWorkerId(worker_id.fetch_add(1, Ordering::Relaxed)),
      57            6 :                             rng: SmallRng::from_entropy(),
      58            6 :                             // used to determine whether we should temporarily skip tasks for fairness.
      59            6 :                             // 99% of estimates will overcount by no more than 4096 samples
      60            6 :                             countmin: CountMinSketch::with_params(
      61            6 :                                 1.0 / (SKETCH_RESET_INTERVAL as f64),
      62            6 :                                 0.01,
      63            6 :                             ),
      64            6 :                             tick: 0,
      65            6 :                         });
      66            6 :                     });
      67            6 :                 })
      68            6 :                 .build()
      69            6 :                 .unwrap();
      70            6 : 
      71            6 :             Self {
      72            6 :                 runtime: Some(runtime),
      73            6 :                 metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
      74            6 :             }
      75            6 :         })
      76            6 :     }
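The parameters passed to CountMinSketch::with_params above are what back the comment
"99% of estimates will overcount by no more than 4096 samples". A minimal worked bound,
assuming with_params(epsilon, delta) follows the standard Count-Min guarantee (with
probability at least 1 - delta, an estimate overcounts by at most epsilon * N, where N is
the total weight inserted) and that one reset window sees roughly SKETCH_RESET_INTERVAL
requests of 4096 iterations each; both assumptions are mine, not stated in this file:

    // Worked bound for the sketch parameters used in ThreadPool::new.
    // Assumption: with_params(epsilon, delta) gives the standard Count-Min
    // guarantee that an estimate overcounts by at most epsilon * N with
    // probability at least 1 - delta, where N is the total inserted weight.
    fn countmin_overcount_bound() -> f64 {
        const SKETCH_RESET_INTERVAL: u64 = 1021;
        let epsilon = 1.0 / (SKETCH_RESET_INTERVAL as f64); // ~0.00098, as in new()
        let _delta = 0.01; // the bound below holds with 99% probability

        // Assumed workload per reset window: ~1021 requests of 4096 iterations each.
        let total_weight = (SKETCH_RESET_INTERVAL * 4096) as f64;

        // (1 / 1021) * (1021 * 4096) = 4096 iterations of possible overcount.
        epsilon * total_weight
    }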
      77              : 
      78            5 :     pub(crate) fn spawn_job(&self, endpoint: EndpointIdInt, pbkdf2: Pbkdf2) -> JobHandle {
      79            5 :         JobHandle(
      80            5 :             self.runtime
      81            5 :                 .as_ref()
      82            5 :                 .unwrap()
      83            5 :                 .spawn(JobSpec { pbkdf2, endpoint }),
      84            5 :         )
      85            5 :     }
      86              : }
      87              : 
      88              : impl Drop for ThreadPool {
      89            3 :     fn drop(&mut self) {
      90            3 :         self.runtime.take().unwrap().shutdown_background();
      91            3 :     }
      92              : }
      93              : 
      94              : struct ThreadRt {
      95              :     pool: Weak<ThreadPool>,
      96              :     id: ThreadPoolWorkerId,
      97              :     rng: SmallRng,
      98              :     countmin: CountMinSketch,
      99              :     tick: u64,
     100              : }
     101              : 
     102              : impl ThreadRt {
     103            6 :     fn should_run(&mut self, job: &JobSpec) -> bool {
     104            6 :         let rate = self
     105            6 :             .countmin
     106            6 :             .inc_and_return(&job.endpoint, job.pbkdf2.cost());
     107              : 
     108              :         const P: f64 = 2000.0;
     109              :         // probability decreases as rate increases.
     110              :         // lower probability, higher chance of being skipped
     111              :         //
     112              :         // estimates (rate in terms of 4096 rounds):
     113              :         // rate = 0    => probability = 100%
     114              :         // rate = 10   => probability = 71.3%
     115              :         // rate = 50   => probability = 62.1%
     116              :         // rate = 500  => probability = 52.3%
     117              :         // rate = 1021 => probability = 49.8%
     118              :         //
      119              :         // My expectation is that the pool queue will only begin backing up at ~1000 rps,
      120              :         // in which case SKETCH_RESET_INTERVAL represents roughly 1 second. Thus, the rates
      121              :         // above can be read as requests per second.
     122            6 :         let probability = P.ln() / (P + rate as f64).ln();
     123            6 :         self.rng.gen_bool(probability)
     124            6 :     }
     125              : }
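The percentage table quoted in should_run can be reproduced from the formula directly.
A small sketch, assuming Pbkdf2::cost() reports raw PBKDF2 iterations (so one standard
4096-round request contributes 4096 to the estimated rate); that reading is my assumption,
but under it the printed values roughly match the table above:

    // Reproduces the probability estimates quoted in should_run.
    // Assumption: the left column of that table counts 4096-round requests,
    // while the sketch accumulates raw iterations (requests * 4096).
    fn skip_probability_table() {
        const P: f64 = 2000.0;
        for requests in [0u64, 10, 50, 500, 1021] {
            let rate = requests * 4096;
            let pct = 100.0 * P.ln() / (P + rate as f64).ln();
            println!("rate = {requests:>4} => probability = {pct:.1}%");
        }
        // Prints approximately 100.0%, 71.2%, 62.1%, 52.3%, 49.9%,
        // in line with the 100% / 71.3% / 62.1% / 52.3% / 49.8% figures above.
    }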
     126              : 
     127              : struct JobSpec {
     128              :     pbkdf2: Pbkdf2,
     129              :     endpoint: EndpointIdInt,
     130              : }
     131              : 
     132              : impl Future for JobSpec {
     133              :     type Output = [u8; 32];
     134              : 
     135            6 :     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
     136            6 :         STATE.with_borrow_mut(|state| {
     137            6 :             let state = state.as_mut().expect("should be set on thread startup");
     138            6 : 
     139            6 :             state.tick = state.tick.wrapping_add(1);
     140            6 :             if state.tick % SKETCH_RESET_INTERVAL == 0 {
     141            0 :                 state.countmin.reset();
     142            6 :             }
     143              : 
     144            6 :             if state.should_run(&self) {
     145            5 :                 if let Some(pool) = state.pool.upgrade() {
     146            5 :                     pool.metrics.worker_task_turns_total.inc(state.id);
     147            5 :                 }
     148              : 
     149            5 :                 match self.pbkdf2.turn() {
     150            5 :                     Poll::Ready(result) => Poll::Ready(result),
     151              :                     // more to do, we shall requeue
     152              :                     Poll::Pending => {
     153            0 :                         cx.waker().wake_by_ref();
     154            0 :                         Poll::Pending
     155              :                     }
     156              :                 }
     157              :             } else {
     158            1 :                 if let Some(pool) = state.pool.upgrade() {
     159            1 :                     pool.metrics.worker_task_skips_total.inc(state.id);
     160            1 :                 }
     161              : 
     162            1 :                 cx.waker().wake_by_ref();
     163            1 :                 Poll::Pending
     164              :             }
     165            6 :         })
     166            6 :     }
     167              : }
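This poll implementation is the heart of requirement 2: a worker never runs a full PBKDF2
computation in one go. Each poll does one bounded turn, and both the skip path and the
unfinished path call wake_by_ref and return Pending, which requeues the job behind whatever
else is waiting on the runtime. A minimal standalone sketch of the same self-requeueing
pattern; ChunkedWork is a hypothetical stand-in, not part of this module:

    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    // Do a bounded chunk of work per poll, then wake ourselves and return
    // Pending so the runtime puts us at the back of its queue instead of
    // blocking a worker thread for the whole computation.
    struct ChunkedWork {
        remaining: u32,
    }

    impl Future for ChunkedWork {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            // One "turn": at most 4096 units of work per poll.
            self.remaining = self.remaining.saturating_sub(4096);
            if self.remaining == 0 {
                Poll::Ready(())
            } else {
                // More to do: requeue ourselves cooperatively.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }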
     168              : 
     169              : pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>);
     170              : 
     171              : impl Future for JobHandle {
     172              :     type Output = [u8; 32];
     173              : 
     174           10 :     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
     175           10 :         match self.0.poll_unpin(cx) {
     176            5 :             Poll::Ready(Ok(ok)) => Poll::Ready(ok),
     177            0 :             Poll::Ready(Err(err)) => std::panic::resume_unwind(err.into_panic()),
     178            5 :             Poll::Pending => Poll::Pending,
     179              :         }
     180           10 :     }
     181              : }
     182              : 
     183              : impl Drop for JobHandle {
     184            5 :     fn drop(&mut self) {
     185            5 :         self.0.abort();
     186            5 :     }
     187              : }
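JobHandle forwards panics from the worker with resume_unwind and aborts the spawned task
when dropped, so a caller that goes away stops consuming hashing time. A hypothetical
caller-side sketch (the function name and the cancellation channel are mine, not from this
module), assuming the caller lives elsewhere in the proxy crate with ThreadPool, Pbkdf2 and
EndpointIdInt in scope:

    // Hypothetical caller: race the hash against a cancellation signal.
    // When the `cancelled` arm wins, the JobHandle moved into select! is
    // dropped, which aborts the spawned hashing task via Drop above.
    async fn hash_unless_cancelled(
        pool: &ThreadPool,
        endpoint: EndpointIdInt,
        job: Pbkdf2,
        cancelled: tokio::sync::oneshot::Receiver<()>,
    ) -> Option<[u8; 32]> {
        let handle = pool.spawn_job(endpoint, job);
        tokio::select! {
            hash = handle => Some(hash),
            _ = cancelled => None,
        }
    }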
     188              : 
     189              : #[cfg(test)]
     190              : mod tests {
     191              :     use super::*;
     192              :     use crate::types::EndpointId;
     193              : 
     194              :     #[tokio::test]
     195            1 :     async fn hash_is_correct() {
     196            1 :         let pool = ThreadPool::new(1);
     197            1 : 
     198            1 :         let ep = EndpointId::from("foo");
     199            1 :         let ep = EndpointIdInt::from(ep);
     200            1 : 
     201            1 :         let salt = [0x55; 32];
     202            1 :         let actual = pool
     203            1 :             .spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
     204            1 :             .await;
     205            1 : 
     206            1 :         let expected = [
     207            1 :             10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
     208            1 :             178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
     209            1 :         ];
     210            1 :         assert_eq!(actual, expected);
     211            1 :     }
     212              : }
        

Generated by: LCOV version 2.1-beta