//! Custom threadpool implementation for password hashing.
//!
//! Requirements:
//! 1. Fairness per endpoint.
//! 2. Yield support for high iteration counts.

use std::sync::{
    atomic::{AtomicU64, Ordering},
    Arc,
};

use crossbeam_deque::{Injector, Stealer, Worker};
use itertools::Itertools;
use parking_lot::{Condvar, Mutex};
use rand::Rng;
use rand::{rngs::SmallRng, SeedableRng};
use tokio::sync::oneshot;

use crate::{
    intern::EndpointIdInt,
    metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
    scram::countmin::CountMinSketch,
};

use super::pbkdf2::Pbkdf2;

pub struct ThreadPool {
    queue: Injector<JobSpec>,
    stealers: Vec<Stealer<JobSpec>>,
    parkers: Vec<(Condvar, Mutex<ThreadState>)>,
    /// bitpacked representation:
    /// lower 8 bits = number of sleeping threads
    /// next 8 bits = number of idle threads (searching for work)
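    ///
    /// e.g. 2 idle threads and 3 sleeping threads are encoded as
    /// `(2 << 8) | 3 = 0x0203`.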
    counters: AtomicU64,

    pub metrics: Arc<ThreadPoolMetrics>,
}

#[derive(PartialEq)]
enum ThreadState {
    Parked,
    Active,
}

impl ThreadPool {
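    /// Spawn a pool with `n_workers` OS threads, each running [`thread_rt`].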
    pub fn new(n_workers: u8) -> Arc<Self> {
        let workers = (0..n_workers).map(|_| Worker::new_fifo()).collect_vec();
        let stealers = workers.iter().map(|w| w.stealer()).collect_vec();

        let parkers = (0..n_workers)
            .map(|_| (Condvar::new(), Mutex::new(ThreadState::Active)))
            .collect_vec();

        let pool = Arc::new(Self {
            queue: Injector::new(),
            stealers,
            parkers,
            // threads start searching for work
            counters: AtomicU64::new((n_workers as u64) << 8),
            metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
        });

        for (i, worker) in workers.into_iter().enumerate() {
            let pool = Arc::clone(&pool);
            std::thread::spawn(move || thread_rt(pool, worker, i));
        }

        pool
    }

    pub(crate) fn spawn_job(
        &self,
        endpoint: EndpointIdInt,
        pbkdf2: Pbkdf2,
    ) -> oneshot::Receiver<[u8; 32]> {
        let (tx, rx) = oneshot::channel();

        let queue_was_empty = self.queue.is_empty();

        self.metrics.injector_queue_depth.inc();
        self.queue.push(JobSpec {
            response: tx,
            pbkdf2,
            endpoint,
        });

        // inspired by <https://github.com/rayon-rs/rayon/blob/3e3962cb8f7b50773bcc360b48a7a674a53a2c77/rayon-core/src/sleep/mod.rs#L242>
        let counts = self.counters.load(Ordering::SeqCst);
        let num_awake_but_idle = (counts >> 8) & 0xff;
        let num_sleepers = counts & 0xff;

        // If the queue is non-empty, then we always wake up a worker
        // -- clearly the existing idle workers aren't enough. Otherwise,
        // check to see if we have enough idle workers.
        if !queue_was_empty || num_awake_but_idle == 0 {
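            // wake at most one thread, and only if any are sleeping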
            let num_to_wake = Ord::min(1, num_sleepers);
            self.wake_any_threads(num_to_wake);
        }

        rx
    }

    #[cold]
    fn wake_any_threads(&self, mut num_to_wake: u64) {
        if num_to_wake > 0 {
            for i in 0..self.parkers.len() {
                if self.wake_specific_thread(i) {
                    num_to_wake -= 1;
                    if num_to_wake == 0 {
                        return;
                    }
                }
            }
        }
    }

    fn wake_specific_thread(&self, index: usize) -> bool {
        let (condvar, lock) = &self.parkers[index];

        let mut state = lock.lock();
        if *state == ThreadState::Parked {
            condvar.notify_one();

            // When the thread went to sleep, it will have incremented
            // this value. When we wake it, it's our job to decrement
            // it. We could have the thread do it, but that would
            // introduce a delay between when the thread was
            // *notified* and when this counter was decremented. That
            // might mislead people with new work into thinking that
            // there are sleeping threads that they should try to
            // wake, when in fact there is nothing left for them to
            // do.
            self.counters.fetch_sub(1, Ordering::SeqCst);
            *state = ThreadState::Active;

            true
        } else {
            false
        }
    }

    fn steal(&self, rng: &mut impl Rng, skip: usize, worker: &Worker<JobSpec>) -> Option<JobSpec> {
        // announce thread as idle (+256 adds 1 to the idle count in the packed counters)
        self.counters.fetch_add(256, Ordering::SeqCst);

        // try to steal from the global queue
        loop {
            match self.queue.steal_batch_and_pop(worker) {
                crossbeam_deque::Steal::Success(job) => {
                    self.metrics
                        .injector_queue_depth
                        .set(self.queue.len() as i64);
                    // no longer idle
                    self.counters.fetch_sub(256, Ordering::SeqCst);
                    return Some(job);
                }
                crossbeam_deque::Steal::Retry => continue,
                crossbeam_deque::Steal::Empty => break,
            }
        }

        // try to steal from our neighbours
        loop {
            let mut retry = false;
            let start = rng.gen_range(0..self.stealers.len());
            let job = (start..self.stealers.len())
                .chain(0..start)
                .filter(|i| *i != skip)
                .find_map(
                    |victim| match self.stealers[victim].steal_batch_and_pop(worker) {
                        crossbeam_deque::Steal::Success(job) => Some(job),
                        crossbeam_deque::Steal::Empty => None,
                        crossbeam_deque::Steal::Retry => {
                            retry = true;
                            None
                        }
                    },
                );
            if job.is_some() {
                // no longer idle
                self.counters.fetch_sub(256, Ordering::SeqCst);
                return job;
            }
            if !retry {
                return None;
            }
        }
    }
}

fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
    /// interval at which we should steal from the global queue
    /// so that tail latencies are managed appropriately
    const STEAL_INTERVAL: usize = 61;

    /// How often to reset the sketch values
    const SKETCH_RESET_INTERVAL: usize = 1021;

    let mut rng = SmallRng::from_entropy();

    // used to determine whether we should temporarily skip tasks for fairness.
    // 99% of estimates will overcount by no more than 4096 samples
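    // (assuming `with_params(epsilon, delta)` follows the usual count-min
    // bounds: with epsilon = 1/1021 and roughly 1021 jobs of ~4096 rounds
    // each per reset interval, the overcount is at most
    // epsilon * total ~= 4096 with probability 1 - delta = 0.99)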
    let mut sketch = CountMinSketch::with_params(1.0 / (SKETCH_RESET_INTERVAL as f64), 0.01);

    let (condvar, lock) = &pool.parkers[index];

    'wait: loop {
        // wait for notification of work
        {
            let mut lock = lock.lock();

            // queue is empty
            pool.metrics
                .worker_queue_depth
                .set(ThreadPoolWorkerId(index), 0);

            // subtract 1 from the idle count, add 1 to the sleeping count
            // (255 = 256 - 1 in the bitpacked representation).
            pool.counters.fetch_sub(255, Ordering::SeqCst);

            *lock = ThreadState::Parked;
            condvar.wait(&mut lock);
        }

        for i in 0.. {
            let Some(mut job) = worker
                .pop()
                .or_else(|| pool.steal(&mut rng, index, &worker))
            else {
                continue 'wait;
            };

            pool.metrics
                .worker_queue_depth
                .set(ThreadPoolWorkerId(index), worker.len() as i64);

            // if the receiver is closed, the task was cancelled: skip the work
            if !job.response.is_closed() {
                let rate = sketch.inc_and_return(&job.endpoint, job.pbkdf2.cost());

                const P: f64 = 2000.0;
                // probability decreases as rate increases.
                // lower probability means a higher chance of being skipped
                //
                // estimates (rate in terms of 4096 rounds):
                // rate = 0    => probability = 100%
                // rate = 10   => probability = 71.3%
                // rate = 50   => probability = 62.1%
                // rate = 500  => probability = 52.3%
                // rate = 1021 => probability = 49.8%
                //
                // My expectation is that the pool queue will only begin backing up at ~1000rps,
                // in which case SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above
                // are in requests per second.
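                //
                // (note: `rate` below is in raw PBKDF2 rounds, so e.g. rate = 10
                // in the table above corresponds to 10 * 4096 = 40960 rounds,
                // giving ln(2000) / ln(2000 + 40960) ~= 0.713)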
                let probability = P.ln() / (P + rate as f64).ln();
                if pool.queue.len() > 32 || rng.gen_bool(probability) {
                    pool.metrics
                        .worker_task_turns_total
                        .inc(ThreadPoolWorkerId(index));

                    match job.pbkdf2.turn() {
                        std::task::Poll::Ready(result) => {
                            let _ = job.response.send(result);
                        }
                        std::task::Poll::Pending => worker.push(job),
                    }
                } else {
                    pool.metrics
                        .worker_task_skips_total
                        .inc(ThreadPoolWorkerId(index));

                    // skip for now
                    worker.push(job);
                }
            }

            // if we get stuck with a few long-lived jobs in the queue
            // it's better to try and steal from the queue too for fairness
            if i % STEAL_INTERVAL == 0 {
                let _ = pool.queue.steal_batch(&worker);
            }

            if i % SKETCH_RESET_INTERVAL == 0 {
                sketch.reset();
            }
        }
    }
}

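/// A single password-hashing job: the in-progress PBKDF2 state, the endpoint
/// it belongs to (for fairness accounting), and the channel for the result.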
struct JobSpec {
    response: oneshot::Sender<[u8; 32]>,
    pbkdf2: Pbkdf2,
    endpoint: EndpointIdInt,
}

#[cfg(test)]
mod tests {
    use crate::EndpointId;

    use super::*;

    #[tokio::test]
    async fn hash_is_correct() {
        let pool = ThreadPool::new(1);

        let ep = EndpointId::from("foo");
        let ep = EndpointIdInt::from(ep);

        let salt = [0x55; 32];
        let actual = pool
            .spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
            .await
            .unwrap();

        let expected = [
            10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
            178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
        ];
        assert_eq!(actual, expected);
    }
}