LCOV - code coverage report
Current view: top level - storage_controller/src - id_lock_map.rs (source / functions) Coverage Total Hit
Test: 3eba1babe267649f8cebefc91c236589db030548.info Lines: 64.2 % 123 79
Test Date: 2024-11-22 12:36:12 Functions: 33.3 % 39 13

            Line data    Source code
       1              : use std::fmt::Display;
       2              : use std::time::Instant;
       3              : use std::{collections::HashMap, sync::Arc};
       4              : 
       5              : use std::time::Duration;
       6              : 
       7              : use crate::service::RECONCILE_TIMEOUT;
       8              : 
       9              : const LOCK_TIMEOUT_ALERT_THRESHOLD: Duration = RECONCILE_TIMEOUT;
      10              : 
      11              : /// A wrapper around `OwnedRwLockWriteGuard` used for tracking the
      12              : /// operation that holds the lock, and print a warning if it exceeds
      13              : /// the LOCK_TIMEOUT_ALERT_THRESHOLD time
      14              : pub struct TracingExclusiveGuard<T: Display> {
      15              :     guard: tokio::sync::OwnedRwLockWriteGuard<Option<T>>,
      16              :     start: Instant,
      17              : }
      18              : 
      19              : impl<T: Display> TracingExclusiveGuard<T> {
      20            1 :     pub fn new(guard: tokio::sync::OwnedRwLockWriteGuard<Option<T>>) -> Self {
      21            1 :         Self {
      22            1 :             guard,
      23            1 :             start: Instant::now(),
      24            1 :         }
      25            1 :     }
      26              : }
      27              : 
      28              : impl<T: Display> Drop for TracingExclusiveGuard<T> {
      29            1 :     fn drop(&mut self) {
      30            1 :         let duration = self.start.elapsed();
      31            1 :         if duration > LOCK_TIMEOUT_ALERT_THRESHOLD {
      32            0 :             tracing::warn!(
      33            0 :                 "Exclusive lock by {} was held for {:?}",
      34            0 :                 self.guard.as_ref().unwrap(),
      35              :                 duration
      36              :             );
      37            1 :         }
      38            1 :         *self.guard = None;
      39            1 :     }
      40              : }
      41              : 
      42              : // A wrapper around `OwnedRwLockReadGuard` used for tracking the
      43              : /// operation that holds the lock, and print a warning if it exceeds
      44              : /// the LOCK_TIMEOUT_ALERT_THRESHOLD time
      45              : pub struct TracingSharedGuard<T: Display> {
      46              :     _guard: tokio::sync::OwnedRwLockReadGuard<Option<T>>,
      47              :     operation: T,
      48              :     start: Instant,
      49              : }
      50              : 
      51              : impl<T: Display> TracingSharedGuard<T> {
      52            3 :     pub fn new(guard: tokio::sync::OwnedRwLockReadGuard<Option<T>>, operation: T) -> Self {
      53            3 :         Self {
      54            3 :             _guard: guard,
      55            3 :             operation,
      56            3 :             start: Instant::now(),
      57            3 :         }
      58            3 :     }
      59              : }
      60              : 
      61              : impl<T: Display> Drop for TracingSharedGuard<T> {
      62            3 :     fn drop(&mut self) {
      63            3 :         let duration = self.start.elapsed();
      64            3 :         if duration > LOCK_TIMEOUT_ALERT_THRESHOLD {
      65            0 :             tracing::warn!(
      66            0 :                 "Shared lock by {} was held for {:?}",
      67              :                 self.operation,
      68              :                 duration
      69              :             );
      70            3 :         }
      71            3 :     }
      72              : }
      73              : 
      74              : /// A map of locks covering some arbitrary identifiers. Useful if you have a collection of objects but don't
      75              : /// want to embed a lock in each one, or if your locking granularity is different to your object granularity.
      76              : /// For example, used in the storage controller where the objects are tenant shards, but sometimes locking
      77              : /// is needed at a tenant-wide granularity.
      78              : pub(crate) struct IdLockMap<T, I>
      79              : where
      80              :     T: Eq + PartialEq + std::hash::Hash,
      81              : {
      82              :     /// A synchronous lock for getting/setting the async locks that our callers will wait on.
      83              :     entities: std::sync::Mutex<std::collections::HashMap<T, Arc<tokio::sync::RwLock<Option<I>>>>>,
      84              : }
      85              : 
      86              : impl<T, I> IdLockMap<T, I>
      87              : where
      88              :     T: Eq + PartialEq + std::hash::Hash,
      89              :     I: Display,
      90              : {
      91            3 :     pub(crate) fn shared(
      92            3 :         &self,
      93            3 :         key: T,
      94            3 :         operation: I,
      95            3 :     ) -> impl std::future::Future<Output = TracingSharedGuard<I>> {
      96            3 :         let mut locked = self.entities.lock().unwrap();
      97            3 :         let entry = locked.entry(key).or_default().clone();
      98            3 :         async move { TracingSharedGuard::new(entry.read_owned().await, operation) }
      99            3 :     }
     100              : 
     101            2 :     pub(crate) fn exclusive(
     102            2 :         &self,
     103            2 :         key: T,
     104            2 :         operation: I,
     105            2 :     ) -> impl std::future::Future<Output = TracingExclusiveGuard<I>> {
     106            2 :         let mut locked = self.entities.lock().unwrap();
     107            2 :         let entry = locked.entry(key).or_default().clone();
     108            2 :         async move {
     109            2 :             let mut guard = TracingExclusiveGuard::new(entry.write_owned().await);
     110            1 :             *guard.guard = Some(operation);
     111            1 :             guard
     112            1 :         }
     113            2 :     }
     114              : 
     115              :     /// Rather than building a lock guard that re-takes the [`Self::entities`] lock, we just do
     116              :     /// periodic housekeeping to avoid the map growing indefinitely
     117            0 :     pub(crate) fn housekeeping(&self) {
     118            0 :         let mut locked = self.entities.lock().unwrap();
     119            0 :         locked.retain(|_k, entry| entry.try_write().is_err())
     120            0 :     }
     121              : }
     122              : 
     123              : impl<T, I> Default for IdLockMap<T, I>
     124              : where
     125              :     T: Eq + PartialEq + std::hash::Hash,
     126              : {
     127            2 :     fn default() -> Self {
     128            2 :         Self {
     129            2 :             entities: std::sync::Mutex::new(HashMap::new()),
     130            2 :         }
     131            2 :     }
     132              : }
     133              : 
     134            0 : pub async fn trace_exclusive_lock<
     135            0 :     T: Clone + Display + Eq + PartialEq + std::hash::Hash,
     136            0 :     I: Clone + Display,
     137            0 : >(
     138            0 :     op_locks: &IdLockMap<T, I>,
     139            0 :     key: T,
     140            0 :     operation: I,
     141            0 : ) -> TracingExclusiveGuard<I> {
     142            0 :     let start = Instant::now();
     143            0 :     let guard = op_locks.exclusive(key.clone(), operation.clone()).await;
     144              : 
     145            0 :     let duration = start.elapsed();
     146            0 :     if duration > LOCK_TIMEOUT_ALERT_THRESHOLD {
     147            0 :         tracing::warn!(
     148            0 :             "Operation {} on key {} has waited {:?} for exclusive lock",
     149              :             operation,
     150              :             key,
     151              :             duration
     152              :         );
     153            0 :     }
     154              : 
     155            0 :     guard
     156            0 : }
     157              : 
     158            0 : pub async fn trace_shared_lock<
     159            0 :     T: Clone + Display + Eq + PartialEq + std::hash::Hash,
     160            0 :     I: Clone + Display,
     161            0 : >(
     162            0 :     op_locks: &IdLockMap<T, I>,
     163            0 :     key: T,
     164            0 :     operation: I,
     165            0 : ) -> TracingSharedGuard<I> {
     166            0 :     let start = Instant::now();
     167            0 :     let guard = op_locks.shared(key.clone(), operation.clone()).await;
     168              : 
     169            0 :     let duration = start.elapsed();
     170            0 :     if duration > LOCK_TIMEOUT_ALERT_THRESHOLD {
     171            0 :         tracing::warn!(
     172            0 :             "Operation {} on key {} has waited {:?} for shared lock",
     173              :             operation,
     174              :             key,
     175              :             duration
     176              :         );
     177            0 :     }
     178              : 
     179            0 :     guard
     180            0 : }
     181              : 
     182              : #[cfg(test)]
     183              : mod tests {
     184              :     use super::IdLockMap;
     185              : 
     186            0 :     #[derive(Clone, Debug, strum_macros::Display, PartialEq)]
     187              :     enum Operations {
     188              :         Op1,
     189              :         Op2,
     190              :     }
     191              : 
     192              :     #[tokio::test]
     193            1 :     async fn multiple_shared_locks() {
     194            1 :         let id_lock_map: IdLockMap<i32, Operations> = IdLockMap::default();
     195            1 : 
     196            1 :         let shared_lock_1 = id_lock_map.shared(1, Operations::Op1).await;
     197            1 :         let shared_lock_2 = id_lock_map.shared(1, Operations::Op2).await;
     198            1 : 
     199            1 :         assert_eq!(shared_lock_1.operation, Operations::Op1);
     200            1 :         assert_eq!(shared_lock_2.operation, Operations::Op2);
     201            1 :     }
     202              : 
     203              :     #[tokio::test]
     204            1 :     async fn exclusive_locks() {
     205            1 :         let id_lock_map = IdLockMap::default();
     206            1 :         let resource_id = 1;
     207            1 : 
     208            1 :         {
     209            1 :             let _ex_lock = id_lock_map.exclusive(resource_id, Operations::Op1).await;
     210            1 :             assert_eq!(_ex_lock.guard.clone().unwrap(), Operations::Op1);
     211            1 : 
     212            1 :             let _ex_lock_2 = tokio::time::timeout(
     213            1 :                 tokio::time::Duration::from_millis(1),
     214            1 :                 id_lock_map.exclusive(resource_id, Operations::Op2),
     215            1 :             )
     216            1 :             .await;
     217            1 :             assert!(_ex_lock_2.is_err());
     218            1 :         }
     219            1 : 
     220            1 :         let shared_lock_1 = id_lock_map.shared(resource_id, Operations::Op1).await;
     221            1 :         assert_eq!(shared_lock_1.operation, Operations::Op1);
     222            1 :     }
     223              : }
        

Generated by: LCOV version 2.1-beta