Line data Source code
1 : //! Custom threadpool implementation for password hashing.
2 : //!
3 : //! Requirements:
4 : //! 1. Fairness per endpoint.
5 : //! 2. Yield support for high iteration counts.
6 :
7 : use std::{
8 : cell::RefCell,
9 : future::Future,
10 : pin::Pin,
11 : sync::{
12 : atomic::{AtomicUsize, Ordering},
13 : Arc, Weak,
14 : },
15 : task::{Context, Poll},
16 : };
17 :
18 : use futures::FutureExt;
19 : use rand::Rng;
20 : use rand::{rngs::SmallRng, SeedableRng};
21 :
22 : use crate::{
23 : intern::EndpointIdInt,
24 : metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
25 : scram::countmin::CountMinSketch,
26 : };
27 :
28 : use super::pbkdf2::Pbkdf2;
29 :
30 : pub struct ThreadPool {
31 : runtime: Option<tokio::runtime::Runtime>,
32 : pub metrics: Arc<ThreadPoolMetrics>,
33 : }
34 :
/// How often to reset the sketch values.
///
/// Measured in ticks: a worker thread's tick counter advances by one every time
/// it polls a job, and its count-min sketch is cleared whenever the counter is a
/// multiple of this value.
const SKETCH_RESET_INTERVAL: u64 = 1021;
37 :
38 : thread_local! {
39 : static STATE: RefCell<Option<ThreadRt>> = const { RefCell::new(None) };
40 : }
41 :
42 : impl ThreadPool {
43 6 : pub fn new(n_workers: u8) -> Arc<Self> {
44 6 : // rayon would be nice here, but yielding in rayon does not work well afaict.
45 6 :
46 6 : Arc::new_cyclic(|pool| {
47 6 : let pool = pool.clone();
48 6 : let worker_id = AtomicUsize::new(0);
49 6 :
50 6 : let runtime = tokio::runtime::Builder::new_multi_thread()
51 6 : .worker_threads(n_workers as usize)
52 6 : .on_thread_start(move || {
53 6 : STATE.with_borrow_mut(|state| {
54 6 : *state = Some(ThreadRt {
55 6 : pool: pool.clone(),
56 6 : id: ThreadPoolWorkerId(worker_id.fetch_add(1, Ordering::Relaxed)),
57 6 : rng: SmallRng::from_entropy(),
58 6 : // used to determine whether we should temporarily skip tasks for fairness.
59 6 : // 99% of estimates will overcount by no more than 4096 samples
60 6 : countmin: CountMinSketch::with_params(
61 6 : 1.0 / (SKETCH_RESET_INTERVAL as f64),
62 6 : 0.01,
63 6 : ),
64 6 : tick: 0,
65 6 : });
66 6 : });
67 6 : })
68 6 : .build()
69 6 : .unwrap();
70 6 :
71 6 : Self {
72 6 : runtime: Some(runtime),
73 6 : metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
74 6 : }
75 6 : })
76 6 : }
77 :
78 5 : pub(crate) fn spawn_job(&self, endpoint: EndpointIdInt, pbkdf2: Pbkdf2) -> JobHandle {
79 5 : JobHandle(
80 5 : self.runtime
81 5 : .as_ref()
82 5 : .unwrap()
83 5 : .spawn(JobSpec { pbkdf2, endpoint }),
84 5 : )
85 5 : }
86 : }
87 :
88 : impl Drop for ThreadPool {
89 3 : fn drop(&mut self) {
90 3 : self.runtime.take().unwrap().shutdown_background();
91 3 : }
92 : }
93 :
94 : struct ThreadRt {
95 : pool: Weak<ThreadPool>,
96 : id: ThreadPoolWorkerId,
97 : rng: SmallRng,
98 : countmin: CountMinSketch,
99 : tick: u64,
100 : }
101 :
102 : impl ThreadRt {
103 8 : fn should_run(&mut self, job: &JobSpec) -> bool {
104 8 : let rate = self
105 8 : .countmin
106 8 : .inc_and_return(&job.endpoint, job.pbkdf2.cost());
107 :
108 : const P: f64 = 2000.0;
109 : // probability decreases as rate increases.
110 : // lower probability, higher chance of being skipped
111 : //
112 : // estimates (rate in terms of 4096 rounds):
113 : // rate = 0 => probability = 100%
114 : // rate = 10 => probability = 71.3%
115 : // rate = 50 => probability = 62.1%
116 : // rate = 500 => probability = 52.3%
117 : // rate = 1021 => probability = 49.8%
118 : //
119 : // My expectation is that the pool queue will only begin backing up at ~1000rps
120 : // in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above
121 : // are in requests per second.
122 8 : let probability = P.ln() / (P + rate as f64).ln();
123 8 : self.rng.gen_bool(probability)
124 8 : }
125 : }
126 :
127 : struct JobSpec {
128 : pbkdf2: Pbkdf2,
129 : endpoint: EndpointIdInt,
130 : }
131 :
132 : impl Future for JobSpec {
133 : type Output = [u8; 32];
134 :
135 8 : fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
136 8 : STATE.with_borrow_mut(|state| {
137 8 : let state = state.as_mut().expect("should be set on thread startup");
138 8 :
139 8 : state.tick = state.tick.wrapping_add(1);
140 8 : if state.tick % SKETCH_RESET_INTERVAL == 0 {
141 0 : state.countmin.reset();
142 8 : }
143 :
144 8 : if state.should_run(&self) {
145 5 : if let Some(pool) = state.pool.upgrade() {
146 5 : pool.metrics.worker_task_turns_total.inc(state.id);
147 5 : }
148 :
149 5 : match self.pbkdf2.turn() {
150 5 : Poll::Ready(result) => Poll::Ready(result),
151 : // more to do, we shall requeue
152 : Poll::Pending => {
153 0 : cx.waker().wake_by_ref();
154 0 : Poll::Pending
155 : }
156 : }
157 : } else {
158 3 : if let Some(pool) = state.pool.upgrade() {
159 3 : pool.metrics.worker_task_skips_total.inc(state.id);
160 3 : }
161 :
162 3 : cx.waker().wake_by_ref();
163 3 : Poll::Pending
164 : }
165 8 : })
166 8 : }
167 : }
168 :
169 : pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>);
170 :
171 : impl Future for JobHandle {
172 : type Output = [u8; 32];
173 :
174 10 : fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
175 10 : match self.0.poll_unpin(cx) {
176 5 : Poll::Ready(Ok(ok)) => Poll::Ready(ok),
177 0 : Poll::Ready(Err(err)) => std::panic::resume_unwind(err.into_panic()),
178 5 : Poll::Pending => Poll::Pending,
179 : }
180 10 : }
181 : }
182 :
183 : impl Drop for JobHandle {
184 5 : fn drop(&mut self) {
185 5 : self.0.abort();
186 5 : }
187 : }
188 :
#[cfg(test)]
mod tests {
    use crate::EndpointId;

    use super::*;

    /// Runs one job on a single-worker pool (password `b"password"`, 32 bytes of
    /// `0x55` as salt, 4096 iterations) and checks the result against a fixed digest.
    #[tokio::test]
    async fn hash_is_correct() {
        const EXPECTED: [u8; 32] = [
            10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
            178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
        ];

        let pool = ThreadPool::new(1);
        let endpoint = EndpointIdInt::from(EndpointId::from("foo"));

        let salt = [0x55; 32];
        let job = Pbkdf2::start(b"password", &salt, 4096);

        let actual = pool.spawn_job(endpoint, job).await;
        assert_eq!(actual, EXPECTED);
    }
}
|