Line data Source code
1 : //! Custom threadpool implementation for password hashing.
2 : //!
3 : //! Requirements:
4 : //! 1. Fairness per endpoint.
5 : //! 2. Yield support for high iteration counts.
6 :
7 : use std::sync::{
8 : atomic::{AtomicU64, Ordering},
9 : Arc,
10 : };
11 :
12 : use crossbeam_deque::{Injector, Stealer, Worker};
13 : use itertools::Itertools;
14 : use parking_lot::{Condvar, Mutex};
15 : use rand::Rng;
16 : use rand::{rngs::SmallRng, SeedableRng};
17 : use tokio::sync::oneshot;
18 :
19 : use crate::{
20 : intern::EndpointIdInt,
21 : metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
22 : scram::countmin::CountMinSketch,
23 : };
24 :
25 : use super::pbkdf2::Pbkdf2;
26 :
27 : pub struct ThreadPool {
28 : queue: Injector<JobSpec>,
29 : stealers: Vec<Stealer<JobSpec>>,
30 : parkers: Vec<(Condvar, Mutex<ThreadState>)>,
31 : /// bitpacked representation.
32 : /// lower 8 bits = number of sleeping threads
33 : /// next 8 bits = number of idle threads (searching for work)
34 : counters: AtomicU64,
35 :
36 : pub metrics: Arc<ThreadPoolMetrics>,
37 : }
38 :
/// Whether a worker thread is currently blocked on its condvar.
#[derive(PartialEq)]
enum ThreadState {
    /// Waiting on its condvar until `wake_specific_thread` flips it back.
    Parked,
    /// Running: either processing jobs or searching for work.
    Active,
}
44 :
45 : impl ThreadPool {
46 12 : pub fn new(n_workers: u8) -> Arc<Self> {
47 12 : let workers = (0..n_workers).map(|_| Worker::new_fifo()).collect_vec();
48 12 : let stealers = workers.iter().map(|w| w.stealer()).collect_vec();
49 12 :
50 12 : let parkers = (0..n_workers)
51 12 : .map(|_| (Condvar::new(), Mutex::new(ThreadState::Active)))
52 12 : .collect_vec();
53 12 :
54 12 : let pool = Arc::new(Self {
55 12 : queue: Injector::new(),
56 12 : stealers,
57 12 : parkers,
58 12 : // threads start searching for work
59 12 : counters: AtomicU64::new((n_workers as u64) << 8),
60 12 : metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
61 12 : });
62 :
63 12 : for (i, worker) in workers.into_iter().enumerate() {
64 12 : let pool = Arc::clone(&pool);
65 12 : std::thread::spawn(move || thread_rt(pool, worker, i));
66 12 : }
67 :
68 12 : pool
69 12 : }
70 :
71 10 : pub fn spawn_job(
72 10 : &self,
73 10 : endpoint: EndpointIdInt,
74 10 : pbkdf2: Pbkdf2,
75 10 : ) -> oneshot::Receiver<[u8; 32]> {
76 10 : let (tx, rx) = oneshot::channel();
77 10 :
78 10 : let queue_was_empty = self.queue.is_empty();
79 10 :
80 10 : self.metrics.injector_queue_depth.inc();
81 10 : self.queue.push(JobSpec {
82 10 : response: tx,
83 10 : pbkdf2,
84 10 : endpoint,
85 10 : });
86 10 :
87 10 : // inspired from <https://github.com/rayon-rs/rayon/blob/3e3962cb8f7b50773bcc360b48a7a674a53a2c77/rayon-core/src/sleep/mod.rs#L242>
88 10 : let counts = self.counters.load(Ordering::SeqCst);
89 10 : let num_awake_but_idle = (counts >> 8) & 0xff;
90 10 : let num_sleepers = counts & 0xff;
91 10 :
92 10 : // If the queue is non-empty, then we always wake up a worker
93 10 : // -- clearly the existing idle jobs aren't enough. Otherwise,
94 10 : // check to see if we have enough idle workers.
95 10 : if !queue_was_empty || num_awake_but_idle == 0 {
96 10 : let num_to_wake = Ord::min(1, num_sleepers);
97 10 : self.wake_any_threads(num_to_wake);
98 10 : }
99 :
100 10 : rx
101 10 : }
102 :
103 : #[cold]
104 10 : fn wake_any_threads(&self, mut num_to_wake: u64) {
105 10 : if num_to_wake > 0 {
106 10 : for i in 0..self.parkers.len() {
107 10 : if self.wake_specific_thread(i) {
108 10 : num_to_wake -= 1;
109 10 : if num_to_wake == 0 {
110 10 : return;
111 0 : }
112 0 : }
113 : }
114 0 : }
115 10 : }
116 :
117 10 : fn wake_specific_thread(&self, index: usize) -> bool {
118 10 : let (condvar, lock) = &self.parkers[index];
119 10 :
120 10 : let mut state = lock.lock();
121 10 : if *state == ThreadState::Parked {
122 10 : condvar.notify_one();
123 10 :
124 10 : // When the thread went to sleep, it will have incremented
125 10 : // this value. When we wake it, its our job to decrement
126 10 : // it. We could have the thread do it, but that would
127 10 : // introduce a delay between when the thread was
128 10 : // *notified* and when this counter was decremented. That
129 10 : // might mislead people with new work into thinking that
130 10 : // there are sleeping threads that they should try to
131 10 : // wake, when in fact there is nothing left for them to
132 10 : // do.
133 10 : self.counters.fetch_sub(1, Ordering::SeqCst);
134 10 : *state = ThreadState::Active;
135 10 :
136 10 : true
137 : } else {
138 0 : false
139 : }
140 10 : }
141 :
142 19 : fn steal(&self, rng: &mut impl Rng, skip: usize, worker: &Worker<JobSpec>) -> Option<JobSpec> {
143 19 : // announce thread as idle
144 19 : self.counters.fetch_add(256, Ordering::SeqCst);
145 :
146 : // try steal from the global queue
147 19 : loop {
148 19 : match self.queue.steal_batch_and_pop(worker) {
149 10 : crossbeam_deque::Steal::Success(job) => {
150 10 : self.metrics
151 10 : .injector_queue_depth
152 10 : .set(self.queue.len() as i64);
153 10 : // no longer idle
154 10 : self.counters.fetch_sub(256, Ordering::SeqCst);
155 10 : return Some(job);
156 : }
157 0 : crossbeam_deque::Steal::Retry => continue,
158 9 : crossbeam_deque::Steal::Empty => break,
159 : }
160 : }
161 :
162 : // try steal from our neighbours
163 9 : loop {
164 9 : let mut retry = false;
165 9 : let start = rng.gen_range(0..self.stealers.len());
166 9 : let job = (start..self.stealers.len())
167 9 : .chain(0..start)
168 9 : .filter(|i| *i != skip)
169 9 : .find_map(
170 9 : |victim| match self.stealers[victim].steal_batch_and_pop(worker) {
171 0 : crossbeam_deque::Steal::Success(job) => Some(job),
172 0 : crossbeam_deque::Steal::Empty => None,
173 : crossbeam_deque::Steal::Retry => {
174 0 : retry = true;
175 0 : None
176 : }
177 9 : },
178 9 : );
179 9 : if job.is_some() {
180 : // no longer idle
181 0 : self.counters.fetch_sub(256, Ordering::SeqCst);
182 0 : return job;
183 9 : }
184 9 : if !retry {
185 9 : return None;
186 0 : }
187 : }
188 19 : }
189 : }
190 :
191 12 : fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
192 12 : /// interval when we should steal from the global queue
193 12 : /// so that tail latencies are managed appropriately
194 12 : const STEAL_INTERVAL: usize = 61;
195 12 :
196 12 : /// How often to reset the sketch values
197 12 : const SKETCH_RESET_INTERVAL: usize = 1021;
198 12 :
199 12 : let mut rng = SmallRng::from_entropy();
200 12 :
201 12 : // used to determine whether we should temporarily skip tasks for fairness.
202 12 : // 99% of estimates will overcount by no more than 4096 samples
203 12 : let mut sketch = CountMinSketch::with_params(1.0 / (SKETCH_RESET_INTERVAL as f64), 0.01);
204 12 :
205 12 : let (condvar, lock) = &pool.parkers[index];
206 :
207 21 : 'wait: loop {
208 21 : // wait for notification of work
209 21 : {
210 21 : let mut lock = lock.lock();
211 21 :
212 21 : // queue is empty
213 21 : pool.metrics
214 21 : .worker_queue_depth
215 21 : .set(ThreadPoolWorkerId(index), 0);
216 21 :
217 21 : // subtract 1 from idle count, add 1 to sleeping count.
218 21 : pool.counters.fetch_sub(255, Ordering::SeqCst);
219 21 :
220 21 : *lock = ThreadState::Parked;
221 21 : condvar.wait(&mut lock);
222 21 : }
223 :
224 31 : for i in 0.. {
225 31 : let mut job = match worker
226 31 : .pop()
227 31 : .or_else(|| pool.steal(&mut rng, index, &worker))
228 : {
229 22 : Some(job) => job,
230 9 : None => continue 'wait,
231 : };
232 :
233 22 : pool.metrics
234 22 : .worker_queue_depth
235 22 : .set(ThreadPoolWorkerId(index), worker.len() as i64);
236 22 :
237 22 : // receiver is closed, cancel the task
238 22 : if !job.response.is_closed() {
239 22 : let rate = sketch.inc_and_return(&job.endpoint, job.pbkdf2.cost());
240 22 :
241 22 : const P: f64 = 2000.0;
242 22 : // probability decreases as rate increases.
243 22 : // lower probability, higher chance of being skipped
244 22 : //
245 22 : // estimates (rate in terms of 4096 rounds):
246 22 : // rate = 0 => probability = 100%
247 22 : // rate = 10 => probability = 71.3%
248 22 : // rate = 50 => probability = 62.1%
249 22 : // rate = 500 => probability = 52.3%
250 22 : // rate = 1021 => probability = 49.8%
251 22 : //
252 22 : // My expectation is that the pool queue will only begin backing up at ~1000rps
253 22 : // in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above
254 22 : // are in requests per second.
255 22 : let probability = P.ln() / (P + rate as f64).ln();
256 22 : if pool.queue.len() > 32 || rng.gen_bool(probability) {
257 21 : pool.metrics
258 21 : .worker_task_turns_total
259 21 : .inc(ThreadPoolWorkerId(index));
260 21 :
261 21 : match job.pbkdf2.turn() {
262 10 : std::task::Poll::Ready(result) => {
263 10 : let _ = job.response.send(result);
264 10 : }
265 0 : std::task::Poll::Pending => worker.push(job),
266 : }
267 : } else {
268 1 : pool.metrics
269 1 : .worker_task_skips_total
270 1 : .inc(ThreadPoolWorkerId(index));
271 1 :
272 1 : // skip for now
273 1 : worker.push(job)
274 : }
275 0 : }
276 :
277 : // if we get stuck with a few long lived jobs in the queue
278 : // it's better to try and steal from the queue too for fairness
279 11 : if i % STEAL_INTERVAL == 0 {
280 9 : let _ = pool.queue.steal_batch(&worker);
281 9 : }
282 :
283 11 : if i % SKETCH_RESET_INTERVAL == 0 {
284 9 : sketch.reset();
285 9 : }
286 : }
287 : }
288 : }
289 :
290 : struct JobSpec {
291 : response: oneshot::Sender<[u8; 32]>,
292 : pbkdf2: Pbkdf2,
293 : endpoint: EndpointIdInt,
294 : }
295 :
#[cfg(test)]
mod tests {
    use crate::EndpointId;

    use super::*;

    /// A single 4096-iteration job on a one-thread pool yields the known PBKDF2 output.
    #[tokio::test]
    async fn hash_is_correct() {
        let pool = ThreadPool::new(1);

        let endpoint = EndpointIdInt::from(EndpointId::from("foo"));
        let salt = [0x55; 32];

        let pending = pool.spawn_job(endpoint, Pbkdf2::start(b"password", &salt, 4096));
        let actual = pending.await.unwrap();

        let expected: [u8; 32] = [
            10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
            178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
        ];
        assert_eq!(actual, expected);
    }
}
|