Line data Source code
1 : use std::fmt::Display;
2 : use std::time::Instant;
3 : use std::{collections::HashMap, sync::Arc};
4 :
5 : use std::time::Duration;
6 :
7 : use crate::service::RECONCILE_TIMEOUT;
8 :
9 : const LOCK_TIMEOUT_ALERT_THRESHOLD: Duration = RECONCILE_TIMEOUT;
10 :
11 : /// A wrapper around `OwnedRwLockWriteGuard` used for tracking the
12 : /// operation that holds the lock, and print a warning if it exceeds
13 : /// the LOCK_TIMEOUT_ALERT_THRESHOLD time
14 : pub struct TracingExclusiveGuard<T: Display> {
15 : guard: tokio::sync::OwnedRwLockWriteGuard<Option<T>>,
16 : start: Instant,
17 : }
18 :
19 : impl<T: Display> TracingExclusiveGuard<T> {
20 1 : pub fn new(guard: tokio::sync::OwnedRwLockWriteGuard<Option<T>>) -> Self {
21 1 : Self {
22 1 : guard,
23 1 : start: Instant::now(),
24 1 : }
25 1 : }
26 : }
27 :
28 : impl<T: Display> Drop for TracingExclusiveGuard<T> {
29 1 : fn drop(&mut self) {
30 1 : let duration = self.start.elapsed();
31 1 : if duration > LOCK_TIMEOUT_ALERT_THRESHOLD {
32 0 : tracing::warn!(
33 0 : "Exclusive lock by {} was held for {:?}",
34 0 : self.guard.as_ref().unwrap(),
35 : duration
36 : );
37 1 : }
38 1 : *self.guard = None;
39 1 : }
40 : }
41 :
42 : // A wrapper around `OwnedRwLockReadGuard` used for tracking the
43 : /// operation that holds the lock, and print a warning if it exceeds
44 : /// the LOCK_TIMEOUT_ALERT_THRESHOLD time
45 : pub struct TracingSharedGuard<T: Display> {
46 : _guard: tokio::sync::OwnedRwLockReadGuard<Option<T>>,
47 : operation: T,
48 : start: Instant,
49 : }
50 :
51 : impl<T: Display> TracingSharedGuard<T> {
52 3 : pub fn new(guard: tokio::sync::OwnedRwLockReadGuard<Option<T>>, operation: T) -> Self {
53 3 : Self {
54 3 : _guard: guard,
55 3 : operation,
56 3 : start: Instant::now(),
57 3 : }
58 3 : }
59 : }
60 :
61 : impl<T: Display> Drop for TracingSharedGuard<T> {
62 3 : fn drop(&mut self) {
63 3 : let duration = self.start.elapsed();
64 3 : if duration > LOCK_TIMEOUT_ALERT_THRESHOLD {
65 0 : tracing::warn!(
66 0 : "Shared lock by {} was held for {:?}",
67 : self.operation,
68 : duration
69 : );
70 3 : }
71 3 : }
72 : }
73 :
74 : /// A map of locks covering some arbitrary identifiers. Useful if you have a collection of objects but don't
75 : /// want to embed a lock in each one, or if your locking granularity is different to your object granularity.
76 : /// For example, used in the storage controller where the objects are tenant shards, but sometimes locking
77 : /// is needed at a tenant-wide granularity.
78 : pub(crate) struct IdLockMap<T, I>
79 : where
80 : T: Eq + PartialEq + std::hash::Hash,
81 : {
82 : /// A synchronous lock for getting/setting the async locks that our callers will wait on.
83 : entities: std::sync::Mutex<std::collections::HashMap<T, Arc<tokio::sync::RwLock<Option<I>>>>>,
84 : }
85 :
86 : impl<T, I> IdLockMap<T, I>
87 : where
88 : T: Eq + PartialEq + std::hash::Hash,
89 : I: Display,
90 : {
91 3 : pub(crate) fn shared(
92 3 : &self,
93 3 : key: T,
94 3 : operation: I,
95 3 : ) -> impl std::future::Future<Output = TracingSharedGuard<I>> {
96 3 : let mut locked = self.entities.lock().unwrap();
97 3 : let entry = locked.entry(key).or_default().clone();
98 3 : async move { TracingSharedGuard::new(entry.read_owned().await, operation) }
99 3 : }
100 :
101 2 : pub(crate) fn exclusive(
102 2 : &self,
103 2 : key: T,
104 2 : operation: I,
105 2 : ) -> impl std::future::Future<Output = TracingExclusiveGuard<I>> {
106 2 : let mut locked = self.entities.lock().unwrap();
107 2 : let entry = locked.entry(key).or_default().clone();
108 2 : async move {
109 2 : let mut guard = TracingExclusiveGuard::new(entry.write_owned().await);
110 1 : *guard.guard = Some(operation);
111 1 : guard
112 1 : }
113 2 : }
114 :
115 : /// Rather than building a lock guard that re-takes the [`Self::entities`] lock, we just do
116 : /// periodic housekeeping to avoid the map growing indefinitely
117 0 : pub(crate) fn housekeeping(&self) {
118 0 : let mut locked = self.entities.lock().unwrap();
119 0 : locked.retain(|_k, entry| entry.try_write().is_err())
120 0 : }
121 : }
122 :
123 : impl<T, I> Default for IdLockMap<T, I>
124 : where
125 : T: Eq + PartialEq + std::hash::Hash,
126 : {
127 2 : fn default() -> Self {
128 2 : Self {
129 2 : entities: std::sync::Mutex::new(HashMap::new()),
130 2 : }
131 2 : }
132 : }
133 :
134 0 : pub async fn trace_exclusive_lock<
135 0 : T: Clone + Display + Eq + PartialEq + std::hash::Hash,
136 0 : I: Clone + Display,
137 0 : >(
138 0 : op_locks: &IdLockMap<T, I>,
139 0 : key: T,
140 0 : operation: I,
141 0 : ) -> TracingExclusiveGuard<I> {
142 0 : let start = Instant::now();
143 0 : let guard = op_locks.exclusive(key.clone(), operation.clone()).await;
144 :
145 0 : let duration = start.elapsed();
146 0 : if duration > LOCK_TIMEOUT_ALERT_THRESHOLD {
147 0 : tracing::warn!(
148 0 : "Operation {} on key {} has waited {:?} for exclusive lock",
149 : operation,
150 : key,
151 : duration
152 : );
153 0 : }
154 :
155 0 : guard
156 0 : }
157 :
158 0 : pub async fn trace_shared_lock<
159 0 : T: Clone + Display + Eq + PartialEq + std::hash::Hash,
160 0 : I: Clone + Display,
161 0 : >(
162 0 : op_locks: &IdLockMap<T, I>,
163 0 : key: T,
164 0 : operation: I,
165 0 : ) -> TracingSharedGuard<I> {
166 0 : let start = Instant::now();
167 0 : let guard = op_locks.shared(key.clone(), operation.clone()).await;
168 :
169 0 : let duration = start.elapsed();
170 0 : if duration > LOCK_TIMEOUT_ALERT_THRESHOLD {
171 0 : tracing::warn!(
172 0 : "Operation {} on key {} has waited {:?} for shared lock",
173 : operation,
174 : key,
175 : duration
176 : );
177 0 : }
178 :
179 0 : guard
180 0 : }
181 :
#[cfg(test)]
mod tests {
    use super::IdLockMap;

    #[derive(Clone, Debug, strum_macros::Display, PartialEq)]
    enum Operations {
        Op1,
        Op2,
    }

    /// Two shared locks on the same key can be held at the same time.
    #[tokio::test]
    async fn multiple_shared_locks() {
        let map: IdLockMap<i32, Operations> = IdLockMap::default();

        let first = map.shared(1, Operations::Op1).await;
        let second = map.shared(1, Operations::Op2).await;

        assert_eq!(first.operation, Operations::Op1);
        assert_eq!(second.operation, Operations::Op2);
    }

    /// A held exclusive lock blocks a second exclusive lock on the same key. Once it
    /// is released, a shared lock can be taken.
    #[tokio::test]
    async fn exclusive_locks() {
        let map = IdLockMap::default();
        let key = 1;

        {
            let held = map.exclusive(key, Operations::Op1).await;
            assert_eq!(*held.guard, Some(Operations::Op1));

            let contended = tokio::time::timeout(
                tokio::time::Duration::from_millis(1),
                map.exclusive(key, Operations::Op2),
            )
            .await;
            assert!(contended.is_err());
        }

        let reader = map.shared(key, Operations::Op1).await;
        assert_eq!(reader.operation, Operations::Op1);
    }
}
|