//! Custom threadpool implementation for password hashing.
//!
//! Requirements:
//! 1. Fairness per endpoint.
//! 2. Yield support for high iteration counts.
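//!
//! Rough usage sketch (see the tests at the bottom of this file for a concrete,
//! runnable example; `endpoint`, `salt` and the iteration count below are placeholders):
//!
//! ```ignore
//! let pool = ThreadPool::new(4);
//! let block = pool
//!     .spawn_job(endpoint, Pbkdf2::start(b"password", &salt, 4096))
//!     .await;
//! ```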

use std::cell::RefCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Weak};
use std::task::{Context, Poll};

use futures::FutureExt;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use super::cache::Pbkdf2Cache;
use super::pbkdf2;
use super::pbkdf2::Pbkdf2;
use crate::intern::EndpointIdInt;
use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
use crate::scram::countmin::CountMinSketch;

pub struct ThreadPool {
    runtime: Option<tokio::runtime::Runtime>,
    pub metrics: Arc<ThreadPoolMetrics>,

    // we hash a lot of passwords.
    // we keep a cache of partial hashes for faster validation.
    pub(super) cache: Pbkdf2Cache,
}

/// How often to reset the sketch values
const SKETCH_RESET_INTERVAL: u64 = 1021;

thread_local! {
    static STATE: RefCell<Option<ThreadRt>> = const { RefCell::new(None) };
}

impl ThreadPool {
    pub fn new(mut n_workers: u8) -> Arc<Self> {
        // rayon would be nice here, but yielding in rayon does not work well afaict.

        if n_workers == 0 {
            n_workers = 1;
        }

        Arc::new_cyclic(|pool| {
            let pool = pool.clone();
            let worker_id = AtomicUsize::new(0);

            let runtime = tokio::runtime::Builder::new_multi_thread()
                .worker_threads(n_workers as usize)
                .on_thread_start(move || {
                    STATE.with_borrow_mut(|state| {
                        *state = Some(ThreadRt {
                            pool: pool.clone(),
                            id: ThreadPoolWorkerId(worker_id.fetch_add(1, Ordering::Relaxed)),
                            rng: SmallRng::from_os_rng(),
                            // used to determine whether we should temporarily skip tasks for fairness.
                            // 99% of estimates will overcount by no more than 4096 samples
                            countmin: CountMinSketch::with_params(
                                1.0 / (SKETCH_RESET_INTERVAL as f64),
                                0.01,
                            ),
                            tick: 0,
                        });
                    });
                })
                .build()
                .expect("password threadpool runtime should be configured correctly");

            Self {
                runtime: Some(runtime),
                metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
                cache: Pbkdf2Cache::new(),
            }
        })
    }

    pub(crate) fn spawn_job(&self, endpoint: EndpointIdInt, pbkdf2: Pbkdf2) -> JobHandle {
        JobHandle(
            self.runtime
                .as_ref()
                .expect("runtime is always set")
                .spawn(JobSpec { pbkdf2, endpoint }),
        )
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        self.runtime
            .take()
            .expect("runtime is always set")
            .shutdown_background();
    }
}

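/// Per-worker state, initialised in `on_thread_start` and stored in the `STATE`
/// thread-local of each runtime worker thread.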
struct ThreadRt {
    pool: Weak<ThreadPool>,
    id: ThreadPoolWorkerId,
    rng: SmallRng,
    countmin: CountMinSketch,
    tick: u64,
}

impl ThreadRt {
    fn should_run(&mut self, job: &JobSpec) -> bool {
        let rate = self
            .countmin
            .inc_and_return(&job.endpoint, job.pbkdf2.cost());

        const P: f64 = 2000.0;
        // probability decreases as rate increases.
        // lower probability, higher chance of being skipped.
        //
        // estimates (rate in terms of 4096 rounds):
        // rate = 0    => probability = 100%
        // rate = 10   => probability = 71.3%
        // rate = 50   => probability = 62.1%
        // rate = 500  => probability = 52.3%
        // rate = 1021 => probability = 49.8%
        // (these estimates are spot-checked in the tests module at the bottom of this file.)
        //
        // My expectation is that the pool queue will only begin backing up at ~1000rps,
        // in which case SKETCH_RESET_INTERVAL represents roughly 1 second. Thus, the rates
        // above can be read as requests per second.
        let probability = P.ln() / (P + rate as f64).ln();
        self.rng.random_bool(probability)
    }
}

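/// A password hashing job together with the endpoint it was submitted for;
/// the endpoint is what the per-thread sketch uses to enforce fairness.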
struct JobSpec {
    pbkdf2: Pbkdf2,
    endpoint: EndpointIdInt,
}

impl Future for JobSpec {
    type Output = pbkdf2::Block;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        STATE.with_borrow_mut(|state| {
            let state = state.as_mut().expect("should be set on thread startup");

            state.tick = state.tick.wrapping_add(1);
            if state.tick.is_multiple_of(SKETCH_RESET_INTERVAL) {
                state.countmin.reset();
            }

            if state.should_run(&self) {
                if let Some(pool) = state.pool.upgrade() {
                    pool.metrics.worker_task_turns_total.inc(state.id);
                }

                match self.pbkdf2.turn() {
                    Poll::Ready(result) => Poll::Ready(result),
                    // more to do, we shall requeue
                    Poll::Pending => {
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    }
                }
            } else {
                if let Some(pool) = state.pool.upgrade() {
                    pool.metrics.worker_task_skips_total.inc(state.id);
                }

                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
    }
}

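/// Handle to a hashing job spawned on the pool. Awaiting it yields the resulting
/// PBKDF2 block; a panic inside the job is propagated to the awaiting task.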
pub(crate) struct JobHandle(tokio::task::JoinHandle<pbkdf2::Block>);

impl Future for JobHandle {
    type Output = pbkdf2::Block;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.0.poll_unpin(cx) {
            Poll::Ready(Ok(ok)) => Poll::Ready(ok),
            Poll::Ready(Err(err)) => std::panic::resume_unwind(err.into_panic()),
            Poll::Pending => Poll::Pending,
        }
    }
}

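// dropping the handle aborts the spawned job so abandoned work is cancelled.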
impl Drop for JobHandle {
    fn drop(&mut self) {
        self.0.abort();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::EndpointId;

    #[tokio::test]
    async fn hash_is_correct() {
        let pool = ThreadPool::new(1);

        let ep = EndpointId::from("foo");
        let ep = EndpointIdInt::from(ep);

        let salt = [0x55; 32];
        let actual = pool
            .spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
            .await;

        let expected = &[
            10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
            178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
        ];
        assert_eq!(actual.as_slice(), expected);
    }
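
    /// Loose spot-check of the probability table documented in `ThreadRt::should_run`.
    /// The table expresses `rate` in units of 4096 rounds, so the raw rate is scaled
    /// accordingly before it is fed into the formula.
    #[test]
    fn fairness_probability_estimates() {
        const P: f64 = 2000.0;
        let probability = |rate_in_4096_round_units: u64| {
            let rate = (rate_in_4096_round_units * 4096) as f64;
            P.ln() / (P + rate).ln()
        };

        assert!((probability(0) - 1.000).abs() < 0.005);
        assert!((probability(10) - 0.713).abs() < 0.005);
        assert!((probability(50) - 0.621).abs() < 0.005);
        assert!((probability(500) - 0.523).abs() < 0.005);
        assert!((probability(1021) - 0.498).abs() < 0.005);
    }

    /// Rough sanity check that a job with a much higher iteration count still completes.
    /// Assuming each call to `Pbkdf2::turn` processes a bounded chunk of iterations, such
    /// a job takes several turns, yielding and requeueing itself in between (see
    /// `JobSpec::poll`). Only completion is asserted, not a specific digest.
    #[tokio::test]
    async fn high_iteration_count_completes() {
        let pool = ThreadPool::new(1);

        let ep = EndpointId::from("foo");
        let ep = EndpointIdInt::from(ep);

        let salt = [0x55; 32];
        let block = pool
            .spawn_job(ep, Pbkdf2::start(b"password", &salt, 100_000))
            .await;

        assert_eq!(block.as_slice().len(), 32);
    }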
}