Line data Source code
1 : //! A set for cancelling random http connections
2 :
3 : use std::{
4 : hash::{BuildHasher, BuildHasherDefault},
5 : num::NonZeroUsize,
6 : time::Duration,
7 : };
8 :
9 : use indexmap::IndexMap;
10 : use parking_lot::Mutex;
11 : use rand::{thread_rng, Rng};
12 : use rustc_hash::FxHasher;
13 : use tokio::time::Instant;
14 : use tokio_util::sync::CancellationToken;
15 : use uuid::Uuid;
16 :
17 : type Hasher = BuildHasherDefault<FxHasher>;
18 :
19 : pub struct CancelSet {
20 : shards: Box<[Mutex<CancelShard>]>,
21 : // keyed by random uuid, fxhasher is fine
22 : hasher: Hasher,
23 : }
24 :
25 : pub(crate) struct CancelShard {
26 : tokens: IndexMap<uuid::Uuid, (Instant, CancellationToken), Hasher>,
27 : }
28 :
29 : impl CancelSet {
30 1 : pub fn new(shards: usize) -> Self {
31 1 : CancelSet {
32 1 : shards: (0..shards)
33 1 : .map(|_| {
34 0 : Mutex::new(CancelShard {
35 0 : tokens: IndexMap::with_hasher(Hasher::default()),
36 0 : })
37 1 : })
38 1 : .collect(),
39 1 : hasher: Hasher::default(),
40 1 : }
41 1 : }
42 :
43 0 : pub(crate) fn take(&self) -> Option<CancellationToken> {
44 0 : for _ in 0..4 {
45 0 : if let Some(token) = self.take_raw(thread_rng().gen()) {
46 0 : return Some(token);
47 0 : }
48 0 : tracing::trace!("failed to get cancel token");
49 : }
50 0 : None
51 0 : }
52 :
53 0 : pub(crate) fn take_raw(&self, rng: usize) -> Option<CancellationToken> {
54 0 : NonZeroUsize::new(self.shards.len())
55 0 : .and_then(|len| self.shards[rng % len].lock().take(rng / len))
56 0 : }
57 :
58 0 : pub(crate) fn insert(&self, id: uuid::Uuid, token: CancellationToken) -> CancelGuard<'_> {
59 0 : let shard = NonZeroUsize::new(self.shards.len()).map(|len| {
60 0 : let hash = self.hasher.hash_one(id) as usize;
61 0 : let shard = &self.shards[hash % len];
62 0 : shard.lock().insert(id, token);
63 0 : shard
64 0 : });
65 0 : CancelGuard { shard, id }
66 0 : }
67 : }
68 :
69 : impl CancelShard {
70 0 : fn take(&mut self, rng: usize) -> Option<CancellationToken> {
71 0 : NonZeroUsize::new(self.tokens.len()).and_then(|len| {
72 0 : // 10 second grace period so we don't cancel new connections
73 0 : if self.tokens.get_index(rng % len)?.1 .0.elapsed() < Duration::from_secs(10) {
74 0 : return None;
75 0 : }
76 :
77 0 : let (_key, (_insert, token)) = self.tokens.swap_remove_index(rng % len)?;
78 0 : Some(token)
79 0 : })
80 0 : }
81 :
82 0 : fn remove(&mut self, id: uuid::Uuid) {
83 0 : self.tokens.swap_remove(&id);
84 0 : }
85 :
86 0 : fn insert(&mut self, id: uuid::Uuid, token: CancellationToken) {
87 0 : self.tokens.insert(id, (Instant::now(), token));
88 0 : }
89 : }
90 :
91 : pub(crate) struct CancelGuard<'a> {
92 : shard: Option<&'a Mutex<CancelShard>>,
93 : id: Uuid,
94 : }
95 :
96 : impl Drop for CancelGuard<'_> {
97 0 : fn drop(&mut self) {
98 0 : if let Some(shard) = self.shard {
99 0 : shard.lock().remove(self.id);
100 0 : }
101 0 : }
102 : }
|