Line data Source code
1 : //! A set for cancelling random http connections
2 :
3 : use std::hash::{BuildHasher, BuildHasherDefault};
4 : use std::num::NonZeroUsize;
5 : use std::time::Duration;
6 :
7 : use indexmap::IndexMap;
8 : use parking_lot::Mutex;
9 : use rand::{thread_rng, Rng};
10 : use rustc_hash::FxHasher;
11 : use tokio::time::Instant;
12 : use tokio_util::sync::CancellationToken;
13 : use uuid::Uuid;
14 :
15 : type Hasher = BuildHasherDefault<FxHasher>;
16 :
17 : pub struct CancelSet {
18 : shards: Box<[Mutex<CancelShard>]>,
19 : // keyed by random uuid, fxhasher is fine
20 : hasher: Hasher,
21 : }
22 :
23 : pub(crate) struct CancelShard {
24 : tokens: IndexMap<uuid::Uuid, (Instant, CancellationToken), Hasher>,
25 : }
26 :
27 : impl CancelSet {
28 1 : pub fn new(shards: usize) -> Self {
29 1 : CancelSet {
30 1 : shards: (0..shards)
31 1 : .map(|_| {
32 0 : Mutex::new(CancelShard {
33 0 : tokens: IndexMap::with_hasher(Hasher::default()),
34 0 : })
35 1 : })
36 1 : .collect(),
37 1 : hasher: Hasher::default(),
38 1 : }
39 1 : }
40 :
41 0 : pub(crate) fn take(&self) -> Option<CancellationToken> {
42 0 : for _ in 0..4 {
43 0 : if let Some(token) = self.take_raw(thread_rng().gen()) {
44 0 : return Some(token);
45 0 : }
46 0 : tracing::trace!("failed to get cancel token");
47 : }
48 0 : None
49 0 : }
50 :
51 0 : pub(crate) fn take_raw(&self, rng: usize) -> Option<CancellationToken> {
52 0 : NonZeroUsize::new(self.shards.len())
53 0 : .and_then(|len| self.shards[rng % len].lock().take(rng / len))
54 0 : }
55 :
56 0 : pub(crate) fn insert(&self, id: uuid::Uuid, token: CancellationToken) -> CancelGuard<'_> {
57 0 : let shard = NonZeroUsize::new(self.shards.len()).map(|len| {
58 0 : let hash = self.hasher.hash_one(id) as usize;
59 0 : let shard = &self.shards[hash % len];
60 0 : shard.lock().insert(id, token);
61 0 : shard
62 0 : });
63 0 : CancelGuard { shard, id }
64 0 : }
65 : }
66 :
67 : impl CancelShard {
68 0 : fn take(&mut self, rng: usize) -> Option<CancellationToken> {
69 0 : NonZeroUsize::new(self.tokens.len()).and_then(|len| {
70 0 : // 10 second grace period so we don't cancel new connections
71 0 : if self.tokens.get_index(rng % len)?.1 .0.elapsed() < Duration::from_secs(10) {
72 0 : return None;
73 0 : }
74 :
75 0 : let (_key, (_insert, token)) = self.tokens.swap_remove_index(rng % len)?;
76 0 : Some(token)
77 0 : })
78 0 : }
79 :
80 0 : fn remove(&mut self, id: uuid::Uuid) {
81 0 : self.tokens.swap_remove(&id);
82 0 : }
83 :
84 0 : fn insert(&mut self, id: uuid::Uuid, token: CancellationToken) {
85 0 : self.tokens.insert(id, (Instant::now(), token));
86 0 : }
87 : }
88 :
89 : pub(crate) struct CancelGuard<'a> {
90 : shard: Option<&'a Mutex<CancelShard>>,
91 : id: Uuid,
92 : }
93 :
94 : impl Drop for CancelGuard<'_> {
95 0 : fn drop(&mut self) {
96 0 : if let Some(shard) = self.shard {
97 0 : shard.lock().remove(self.id);
98 0 : }
99 0 : }
100 : }
|