|             Line data    Source code 
       1              : //! Custom threadpool implementation for password hashing.
       2              : //!
       3              : //! Requirements:
       4              : //! 1. Fairness per endpoint.
       5              : //! 2. Yield support for high iteration counts.
       6              : 
       7              : use std::cell::RefCell;
       8              : use std::future::Future;
       9              : use std::pin::Pin;
      10              : use std::sync::atomic::{AtomicUsize, Ordering};
      11              : use std::sync::{Arc, Weak};
      12              : use std::task::{Context, Poll};
      13              : 
      14              : use futures::FutureExt;
      15              : use rand::rngs::SmallRng;
      16              : use rand::{Rng, SeedableRng};
      17              : 
      18              : use super::cache::Pbkdf2Cache;
      19              : use super::pbkdf2;
      20              : use super::pbkdf2::Pbkdf2;
      21              : use crate::intern::EndpointIdInt;
      22              : use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
      23              : use crate::scram::countmin::CountMinSketch;
      24              : 
      25              : pub struct ThreadPool {
      26              :     runtime: Option<tokio::runtime::Runtime>,
      27              :     pub metrics: Arc<ThreadPoolMetrics>,
      28              : 
      29              :     // we hash a lot of passwords.
      30              :     // we keep a cache of partial hashes for faster validation.
      31              :     pub(super) cache: Pbkdf2Cache,
      32              : }
      33              : 
      /// Number of job polls a worker thread performs between two resets of
      /// its sketch values.
      const SKETCH_RESET_INTERVAL: u64 = 1_021;
      36              : 
      37              : thread_local! {
      38              :     static STATE: RefCell<Option<ThreadRt>> = const { RefCell::new(None) };
      39              : }
      40              : 
      41              : impl ThreadPool {
      42            7 :     pub fn new(mut n_workers: u8) -> Arc<Self> {
      43              :         // rayon would be nice here, but yielding in rayon does not work well afaict.
      44              : 
      45            7 :         if n_workers == 0 {
      46            0 :             n_workers = 1;
      47            7 :         }
      48              : 
      49            7 :         Arc::new_cyclic(|pool| {
      50            7 :             let pool = pool.clone();
      51            7 :             let worker_id = AtomicUsize::new(0);
      52              : 
      53            7 :             let runtime = tokio::runtime::Builder::new_multi_thread()
      54            7 :                 .worker_threads(n_workers as usize)
      55            7 :                 .on_thread_start(move || {
      56            7 :                     STATE.with_borrow_mut(|state| {
      57            7 :                         *state = Some(ThreadRt {
      58            7 :                             pool: pool.clone(),
      59            7 :                             id: ThreadPoolWorkerId(worker_id.fetch_add(1, Ordering::Relaxed)),
      60            7 :                             rng: SmallRng::from_os_rng(),
      61            7 :                             // used to determine whether we should temporarily skip tasks for fairness.
      62            7 :                             // 99% of estimates will overcount by no more than 4096 samples
      63            7 :                             countmin: CountMinSketch::with_params(
      64            7 :                                 1.0 / (SKETCH_RESET_INTERVAL as f64),
      65            7 :                                 0.01,
      66            7 :                             ),
      67            7 :                             tick: 0,
      68            7 :                         });
      69            7 :                     });
      70            7 :                 })
      71            7 :                 .build()
      72            7 :                 .expect("password threadpool runtime should be configured correctly");
      73              : 
      74            7 :             Self {
      75            7 :                 runtime: Some(runtime),
      76            7 :                 metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
      77            7 :                 cache: Pbkdf2Cache::new(),
      78            7 :             }
      79            7 :         })
      80            7 :     }
      81              : 
      82           15 :     pub(crate) fn spawn_job(&self, endpoint: EndpointIdInt, pbkdf2: Pbkdf2) -> JobHandle {
      83           15 :         JobHandle(
      84           15 :             self.runtime
      85           15 :                 .as_ref()
      86           15 :                 .expect("runtime is always set")
      87           15 :                 .spawn(JobSpec { pbkdf2, endpoint }),
      88           15 :         )
      89           15 :     }
      90              : }
      91              : 
      92              : impl Drop for ThreadPool {
      93            4 :     fn drop(&mut self) {
      94            4 :         self.runtime
      95            4 :             .take()
      96            4 :             .expect("runtime is always set")
      97            4 :             .shutdown_background();
      98            4 :     }
      99              : }
     100              : 
     101              : struct ThreadRt {
     102              :     pool: Weak<ThreadPool>,
     103              :     id: ThreadPoolWorkerId,
     104              :     rng: SmallRng,
     105              :     countmin: CountMinSketch,
     106              :     tick: u64,
     107              : }
     108              : 
     109              : impl ThreadRt {
     110           20 :     fn should_run(&mut self, job: &JobSpec) -> bool {
     111           20 :         let rate = self
     112           20 :             .countmin
     113           20 :             .inc_and_return(&job.endpoint, job.pbkdf2.cost());
     114              : 
     115              :         const P: f64 = 2000.0;
     116              :         // probability decreases as rate increases.
     117              :         // lower probability, higher chance of being skipped
     118              :         //
     119              :         // estimates (rate in terms of 4096 rounds):
     120              :         // rate = 0    => probability = 100%
     121              :         // rate = 10   => probability = 71.3%
     122              :         // rate = 50   => probability = 62.1%
     123              :         // rate = 500  => probability = 52.3%
     124              :         // rate = 1021 => probability = 49.8%
     125              :         //
     126              :         // My expectation is that the pool queue will only begin backing up at ~1000rps
     127              :         // in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above
     128              :         // are in requests per second.
     129           20 :         let probability = P.ln() / (P + rate as f64).ln();
     130           20 :         self.rng.random_bool(probability)
     131           20 :     }
     132              : }
     133              : 
     134              : struct JobSpec {
     135              :     pbkdf2: Pbkdf2,
     136              :     endpoint: EndpointIdInt,
     137              : }
     138              : 
     139              : impl Future for JobSpec {
     140              :     type Output = pbkdf2::Block;
     141              : 
     142           20 :     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
     143           20 :         STATE.with_borrow_mut(|state| {
     144           20 :             let state = state.as_mut().expect("should be set on thread startup");
     145              : 
     146           20 :             state.tick = state.tick.wrapping_add(1);
     147           20 :             if state.tick.is_multiple_of(SKETCH_RESET_INTERVAL) {
     148            0 :                 state.countmin.reset();
     149           20 :             }
     150              : 
     151           20 :             if state.should_run(&self) {
     152           15 :                 if let Some(pool) = state.pool.upgrade() {
     153           15 :                     pool.metrics.worker_task_turns_total.inc(state.id);
     154           15 :                 }
     155              : 
     156           15 :                 match self.pbkdf2.turn() {
     157           15 :                     Poll::Ready(result) => Poll::Ready(result),
     158              :                     // more to do, we shall requeue
     159              :                     Poll::Pending => {
     160            0 :                         cx.waker().wake_by_ref();
     161            0 :                         Poll::Pending
     162              :                     }
     163              :                 }
     164              :             } else {
     165            5 :                 if let Some(pool) = state.pool.upgrade() {
     166            5 :                     pool.metrics.worker_task_skips_total.inc(state.id);
     167            5 :                 }
     168              : 
     169            5 :                 cx.waker().wake_by_ref();
     170            5 :                 Poll::Pending
     171              :             }
     172           20 :         })
     173           20 :     }
     174              : }
     175              : 
     176              : pub(crate) struct JobHandle(tokio::task::JoinHandle<pbkdf2::Block>);
     177              : 
     178              : impl Future for JobHandle {
     179              :     type Output = pbkdf2::Block;
     180              : 
     181           30 :     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
     182           30 :         match self.0.poll_unpin(cx) {
     183           15 :             Poll::Ready(Ok(ok)) => Poll::Ready(ok),
     184            0 :             Poll::Ready(Err(err)) => std::panic::resume_unwind(err.into_panic()),
     185           15 :             Poll::Pending => Poll::Pending,
     186              :         }
     187           30 :     }
     188              : }
     189              : 
     190              : impl Drop for JobHandle {
     191           15 :     fn drop(&mut self) {
     192           15 :         self.0.abort();
     193           15 :     }
     194              : }
     195              : 
     #[cfg(test)]
     mod tests {
         use super::*;
         use crate::types::EndpointId;

         /// A 4096-round job on a single-worker pool must produce the known block.
         #[tokio::test]
         async fn hash_is_correct() {
             let expected = &[
                 10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
                 178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
             ];

             let pool = ThreadPool::new(1);
             let endpoint = EndpointIdInt::from(EndpointId::from("foo"));

             let salt = [0x55; 32];
             let job = Pbkdf2::start(b"password", &salt, 4096);
             let actual = pool.spawn_job(endpoint, job).await;

             assert_eq!(actual.as_slice(), expected);
         }
     }
         |