LCOV - code coverage report
Current view: top level - libs/neon-shmem/src - hash.rs (source / functions) Coverage Total Hit
Test: 4be46b1c0003aa3bbac9ade362c676b419df4c20.info Lines: 87.5 % 321 281
Test Date: 2025-07-22 17:50:06 Functions: 86.2 % 29 25

            Line data    Source code
       1              : //! Resizable hash table implementation on top of byte-level storage (either a [`ShmemHandle`] or a fixed byte array).
       2              : //!
       3              : //! This hash table has two major components: the bucket array and the dictionary. Each bucket within the
       4              : //! bucket array contains a `Option<(K, V)>` and an index of another bucket. In this way there is both an
       5              : //! implicit freelist within the bucket array (`None` buckets point to other `None` entries) and various hash
       6              : //! chains within the bucket array (a Some bucket will point to other Some buckets that had the same hash).
       7              : //!
       8              : //! Buckets are never moved unless they are within a region that is being shrunk, and so the actual hash-
       9              : //! dependent component is done with the dictionary. When a new key is inserted into the map, a position
      10              : //! within the dictionary is decided based on its hash, the data is inserted into an empty bucket based
      11              : //! off of the freelist, and then the index of said bucket is placed in the dictionary.
      12              : //!
      13              : //! This map is resizable (if initialized on top of a [`ShmemHandle`]). Both growing and shrinking happen
      14              : //! in-place and are at a high level achieved by expanding/reducing the bucket array and rebuilding the
      15              : //! dictionary by rehashing all keys.
      16              : //!
      17              : //! Concurrency is managed very simply: the entire map is guarded by one shared-memory RwLock.
      18              : 
      19              : use std::hash::{BuildHasher, Hash};
      20              : use std::mem::MaybeUninit;
      21              : 
      22              : use crate::shmem::ShmemHandle;
      23              : use crate::{shmem, sync::*};
      24              : 
      25              : mod core;
      26              : pub mod entry;
      27              : 
      28              : #[cfg(test)]
      29              : mod tests;
      30              : 
      31              : use core::{Bucket, CoreHashMap, INVALID_POS};
      32              : use entry::{Entry, OccupiedEntry, PrevPos, VacantEntry};
      33              : 
      34              : use thiserror::Error;
      35              : 
      36              : /// Error type for a hashmap shrink operation.
      37              : #[derive(Error, Debug)]
      38              : pub enum HashMapShrinkError {
      39              :     /// There was an error encountered while resizing the memory area.
      40              :     #[error("shmem resize failed: {0}")]
      41              :     ResizeError(shmem::Error),
      42              :     /// Occupied entries in to-be-shrunk space were encountered beginning at the given index.
      43              :     #[error("occupied entry in deallocated space found at {0}")]
      44              :     RemainingEntries(usize),
      45              : }
      46              : 
      47              : /// This represents a hash table that (possibly) lives in shared memory.
      48              : /// If a new process is launched with fork(), the child process inherits
      49              : /// this struct.
      50              : #[must_use]
      51              : pub struct HashMapInit<'a, K, V, S = rustc_hash::FxBuildHasher> {
      52              :     shmem_handle: Option<ShmemHandle>,
      53              :     shared_ptr: *mut HashMapShared<'a, K, V>,
      54              :     shared_size: usize,
      55              :     hasher: S,
      56              :     num_buckets: u32,
      57              : }
      58              : 
      59              : /// This is a per-process handle to a hash table that (possibly) lives in shared memory.
      60              : /// If a child process is launched with fork(), the child process should
      61              : /// get its own HashMapAccess by calling HashMapInit::attach_writer/reader().
      62              : ///
      63              : /// XXX: We're not making use of it at the moment, but this struct could
      64              : /// hold process-local information in the future.
      65              : pub struct HashMapAccess<'a, K, V, S = rustc_hash::FxBuildHasher> {
      66              :     shmem_handle: Option<ShmemHandle>,
      67              :     shared_ptr: *mut HashMapShared<'a, K, V>,
      68              :     hasher: S,
      69              : }
      70              : 
      71              : unsafe impl<K: Sync, V: Sync, S> Sync for HashMapAccess<'_, K, V, S> {}
      72              : unsafe impl<K: Send, V: Send, S> Send for HashMapAccess<'_, K, V, S> {}
      73              : 
      74              : impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
      75              :     /// Change the 'hasher' used by the hash table.
      76              :     ///
      77              :     /// NOTE: This must be called right after creating the hash table,
      78              :     /// before inserting any entries and before calling attach_writer/reader.
      79              :     /// Otherwise different accessors could be using different hash function,
      80              :     /// with confusing results.
      81            0 :     pub fn with_hasher<T: BuildHasher>(self, hasher: T) -> HashMapInit<'a, K, V, T> {
      82            0 :         HashMapInit {
      83            0 :             hasher,
      84            0 :             shmem_handle: self.shmem_handle,
      85            0 :             shared_ptr: self.shared_ptr,
      86            0 :             shared_size: self.shared_size,
      87            0 :             num_buckets: self.num_buckets,
      88            0 :         }
      89            0 :     }
      90              : 
      91              :     /// Loosely (over)estimate the size needed to store a hash table with `num_buckets` buckets.
      92           60 :     pub fn estimate_size(num_buckets: u32) -> usize {
      93              :         // add some margin to cover alignment etc.
      94           60 :         CoreHashMap::<K, V>::estimate_size(num_buckets) + size_of::<HashMapShared<K, V>>() + 1000
      95           60 :     }
      96              : 
      97           26 :     fn new(
      98           26 :         num_buckets: u32,
      99           26 :         shmem_handle: Option<ShmemHandle>,
     100           26 :         area_ptr: *mut u8,
     101           26 :         area_size: usize,
     102           26 :         hasher: S,
     103           26 :     ) -> Self {
     104           26 :         let mut ptr: *mut u8 = area_ptr;
     105           26 :         let end_ptr: *mut u8 = unsafe { ptr.add(area_size) };
     106              : 
     107              :         // carve out area for the One Big Lock (TM) and the HashMapShared.
     108           26 :         ptr = unsafe { ptr.add(ptr.align_offset(align_of::<libc::pthread_rwlock_t>())) };
     109           26 :         let raw_lock_ptr = ptr;
     110           26 :         ptr = unsafe { ptr.add(size_of::<libc::pthread_rwlock_t>()) };
     111           26 :         ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
     112           26 :         let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
     113           26 :         ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
     114              : 
     115              :         // carve out the buckets
     116           26 :         ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<core::Bucket<K, V>>())) };
     117           26 :         let buckets_ptr = ptr;
     118           26 :         ptr = unsafe { ptr.add(size_of::<core::Bucket<K, V>>() * num_buckets as usize) };
     119              : 
     120              :         // use remaining space for the dictionary
     121           26 :         ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<u32>())) };
     122           26 :         assert!(ptr.addr() < end_ptr.addr());
     123           26 :         let dictionary_ptr = ptr;
     124           26 :         let dictionary_size = unsafe { end_ptr.byte_offset_from(ptr) / size_of::<u32>() as isize };
     125           26 :         assert!(dictionary_size > 0);
     126              : 
     127           26 :         let buckets =
     128           26 :             unsafe { std::slice::from_raw_parts_mut(buckets_ptr.cast(), num_buckets as usize) };
     129           26 :         let dictionary = unsafe {
     130           26 :             std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize)
     131              :         };
     132              : 
     133           26 :         let hashmap = CoreHashMap::new(buckets, dictionary);
     134           26 :         unsafe {
     135           26 :             let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap);
     136           26 :             std::ptr::write(shared_ptr, lock);
     137           26 :         }
     138              : 
     139           26 :         Self {
     140           26 :             num_buckets,
     141           26 :             shmem_handle,
     142           26 :             shared_ptr,
     143           26 :             shared_size: area_size,
     144           26 :             hasher,
     145           26 :         }
     146           26 :     }
     147              : 
     148              :     /// Attach to a hash table for writing.
     149           26 :     pub fn attach_writer(self) -> HashMapAccess<'a, K, V, S> {
     150           26 :         HashMapAccess {
     151           26 :             shmem_handle: self.shmem_handle,
     152           26 :             shared_ptr: self.shared_ptr,
     153           26 :             hasher: self.hasher,
     154           26 :         }
     155           26 :     }
     156              : 
     157              :     /// Initialize a table for reading. Currently identical to [`HashMapInit::attach_writer`].
     158              :     ///
     159              :     /// This is a holdover from a previous implementation and is being kept around for
     160              :     /// backwards compatibility reasons.
     161            0 :     pub fn attach_reader(self) -> HashMapAccess<'a, K, V, S> {
     162            0 :         self.attach_writer()
     163            0 :     }
     164              : }
     165              : 
     166              : /// Hash table data that is actually stored in the shared memory area.
     167              : ///
     168              : /// NOTE: We carve out the parts from a contiguous chunk. Growing and shrinking the hash table
     169              : /// relies on the memory layout! The data structures are laid out in the contiguous shared memory
     170              : /// area as follows:
     171              : ///
     172              : /// [`libc::pthread_rwlock_t`]
     173              : /// [`HashMapShared`]
     174              : /// buckets
     175              : /// dictionary
     176              : ///
     177              : /// In between the above parts, there can be padding bytes to align the parts correctly.
     178              : type HashMapShared<'a, K, V> = RwLock<CoreHashMap<'a, K, V>>;
     179              : 
     180              : impl<'a, K, V> HashMapInit<'a, K, V, rustc_hash::FxBuildHasher>
     181              : where
     182              :     K: Clone + Hash + Eq,
     183              : {
     184              :     /// Place the hash table within a user-supplied fixed memory area.
     185            1 :     pub fn with_fixed(num_buckets: u32, area: &'a mut [MaybeUninit<u8>]) -> Self {
     186            1 :         Self::new(
     187            1 :             num_buckets,
     188            1 :             None,
     189            1 :             area.as_mut_ptr().cast(),
     190            1 :             area.len(),
     191            1 :             rustc_hash::FxBuildHasher,
     192              :         )
     193            1 :     }
     194              : 
     195              :     /// Place a new hash map in the given shared memory area
     196              :     ///
     197              :     /// # Panics
     198              :     /// Will panic on failure to resize area to expected map size.
     199            0 :     pub fn with_shmem(num_buckets: u32, shmem: ShmemHandle) -> Self {
     200            0 :         let size = Self::estimate_size(num_buckets);
     201            0 :         shmem
     202            0 :             .set_size(size)
     203            0 :             .expect("could not resize shared memory area");
     204            0 :         let ptr = shmem.data_ptr.as_ptr().cast();
     205            0 :         Self::new(
     206            0 :             num_buckets,
     207            0 :             Some(shmem),
     208            0 :             ptr,
     209            0 :             size,
     210            0 :             rustc_hash::FxBuildHasher,
     211              :         )
     212            0 :     }
     213              : 
     214              :     /// Make a resizable hash map within a new shared memory area with the given name.
     215           25 :     pub fn new_resizeable_named(num_buckets: u32, max_buckets: u32, name: &str) -> Self {
     216           25 :         let size = Self::estimate_size(num_buckets);
     217           25 :         let max_size = Self::estimate_size(max_buckets);
     218           25 :         let shmem =
     219           25 :             ShmemHandle::new(name, size, max_size).expect("failed to make shared memory area");
     220           25 :         let ptr = shmem.data_ptr.as_ptr().cast();
     221              : 
     222           25 :         Self::new(
     223           25 :             num_buckets,
     224           25 :             Some(shmem),
     225           25 :             ptr,
     226           25 :             size,
     227           25 :             rustc_hash::FxBuildHasher,
     228              :         )
     229           25 :     }
     230              : 
     231              :     /// Make a resizable hash map within a new anonymous shared memory area.
     232            0 :     pub fn new_resizeable(num_buckets: u32, max_buckets: u32) -> Self {
     233              :         use std::sync::atomic::{AtomicUsize, Ordering};
     234              :         static COUNTER: AtomicUsize = AtomicUsize::new(0);
     235            0 :         let val = COUNTER.fetch_add(1, Ordering::Relaxed);
     236            0 :         let name = format!("neon_shmem_hmap{val}");
     237            0 :         Self::new_resizeable_named(num_buckets, max_buckets, &name)
     238            0 :     }
     239              : }
     240              : 
     241              : impl<'a, K, V, S: BuildHasher> HashMapAccess<'a, K, V, S>
     242              : where
     243              :     K: Clone + Hash + Eq,
     244              : {
     245              :     /// Hash a key using the map's hasher.
     246              :     #[inline]
     247       401893 :     fn get_hash_value(&self, key: &K) -> u64 {
     248       401893 :         self.hasher.hash_one(key)
     249       401893 :     }
     250              : 
     251       290557 :     fn entry_with_hash(&self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
     252       290557 :         let mut map = unsafe { self.shared_ptr.as_ref() }.unwrap().write();
     253       290557 :         let dict_pos = hash as usize % map.dictionary.len();
     254       290557 :         let first = map.dictionary[dict_pos];
     255       290557 :         if first == INVALID_POS {
     256              :             // no existing entry
     257       187427 :             return Entry::Vacant(VacantEntry {
     258       187427 :                 map,
     259       187427 :                 key,
     260       187427 :                 dict_pos: dict_pos as u32,
     261       187427 :             });
     262       103130 :         }
     263              : 
     264       103130 :         let mut prev_pos = PrevPos::First(dict_pos as u32);
     265       103130 :         let mut next = first;
     266              :         loop {
     267       113199 :             let bucket = &mut map.buckets[next as usize];
     268       113199 :             let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use");
     269       113199 :             if *bucket_key == key {
     270              :                 // found existing entry
     271        87971 :                 return Entry::Occupied(OccupiedEntry {
     272        87971 :                     map,
     273        87971 :                     _key: key,
     274        87971 :                     prev_pos,
     275        87971 :                     bucket_pos: next,
     276        87971 :                 });
     277        25228 :             }
     278              : 
     279        25228 :             if bucket.next == INVALID_POS {
     280              :                 // No existing entry
     281        15159 :                 return Entry::Vacant(VacantEntry {
     282        15159 :                     map,
     283        15159 :                     key,
     284        15159 :                     dict_pos: dict_pos as u32,
     285        15159 :                 });
     286        10069 :             }
     287        10069 :             prev_pos = PrevPos::Chained(next);
     288        10069 :             next = bucket.next;
     289              :         }
     290       290557 :     }
     291              : 
     292              :     /// Get a reference to the corresponding value for a key.
     293       111088 :     pub fn get<'e>(&'e self, key: &K) -> Option<ValueReadGuard<'e, V>> {
     294       111088 :         let hash = self.get_hash_value(key);
     295       111088 :         let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
     296       111088 :         RwLockReadGuard::try_map(map, |m| m.get_with_hash(key, hash)).ok()
     297       111088 :     }
     298              : 
     299              :     /// Get a reference to the entry containing a key.
     300              :     ///
     301              :     /// NB: THis takes a write lock as there's no way to distinguish whether the intention
     302              :     /// is to use the entry for reading or for writing in advance.
     303       289359 :     pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> {
     304       289359 :         let hash = self.get_hash_value(&key);
     305       289359 :         self.entry_with_hash(key, hash)
     306       289359 :     }
     307              : 
     308              :     /// Remove a key given its hash. Returns the associated value if it existed.
     309          546 :     pub fn remove(&self, key: &K) -> Option<V> {
     310          546 :         let hash = self.get_hash_value(key);
     311          546 :         match self.entry_with_hash(key.clone(), hash) {
     312          546 :             Entry::Occupied(e) => Some(e.remove()),
     313            0 :             Entry::Vacant(_) => None,
     314              :         }
     315          546 :     }
     316              : 
     317              :     /// Insert/update a key. Returns the previous associated value if it existed.
     318              :     ///
     319              :     /// # Errors
     320              :     /// Will return [`core::FullError`] if there is no more space left in the map.
     321          652 :     pub fn insert(&self, key: K, value: V) -> Result<Option<V>, core::FullError> {
     322          652 :         let hash = self.get_hash_value(&key);
     323          652 :         match self.entry_with_hash(key.clone(), hash) {
     324            0 :             Entry::Occupied(mut e) => Ok(Some(e.insert(value))),
     325          652 :             Entry::Vacant(e) => {
     326          652 :                 _ = e.insert(value)?;
     327          651 :                 Ok(None)
     328              :             }
     329              :         }
     330          652 :     }
     331              : 
     332              :     /// Optionally return the entry for a bucket at a given index if it exists.
     333              :     ///
     334              :     /// Has more overhead than one would intuitively expect: performs both a clone of the key
     335              :     /// due to the [`OccupiedEntry`] type owning the key and also a hash of the key in order
     336              :     /// to enable repairing the hash chain if the entry is removed.
     337         3651 :     pub fn entry_at_bucket(&self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
     338         3651 :         let map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
     339         3651 :         if pos >= map.buckets.len() {
     340            0 :             return None;
     341         3651 :         }
     342              : 
     343         3651 :         let entry = map.buckets[pos].inner.as_ref();
     344         3651 :         match entry {
     345          248 :             Some((key, _)) => Some(OccupiedEntry {
     346          248 :                 _key: key.clone(),
     347          248 :                 bucket_pos: pos as u32,
     348          248 :                 prev_pos: entry::PrevPos::Unknown(self.get_hash_value(key)),
     349          248 :                 map,
     350          248 :             }),
     351         3403 :             _ => None,
     352              :         }
     353         3651 :     }
     354              : 
     355              :     /// Returns the number of buckets in the table.
     356            4 :     pub fn get_num_buckets(&self) -> usize {
     357            4 :         let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
     358            4 :         map.get_num_buckets()
     359            4 :     }
     360              : 
     361              :     /// Return the key and value stored in bucket with given index. This can be used to
     362              :     /// iterate through the hash map.
     363              :     // TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
     364              :     // _slowly_ iterate through all buckets with its clock hand,  without holding a lock.
     365              :     // If we switch to an Iterator, it must not hold the lock.
     366          101 :     pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
     367          101 :         let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
     368          101 :         if pos >= map.buckets.len() {
     369            0 :             return None;
     370          101 :         }
     371          101 :         RwLockReadGuard::try_map(map, |m| m.buckets[pos].inner.as_ref()).ok()
     372          101 :     }
     373              : 
     374              :     /// Returns the index of the bucket a given value corresponds to.
     375           43 :     pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
     376           43 :         let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
     377              : 
     378           43 :         let origin = map.buckets.as_ptr();
     379           43 :         let idx = (val_ptr as usize - origin as usize) / size_of::<Bucket<K, V>>();
     380           43 :         assert!(idx < map.buckets.len());
     381              : 
     382           43 :         idx
     383           43 :     }
     384              : 
     385              :     /// Returns the number of occupied buckets in the table.
     386           14 :     pub fn get_num_buckets_in_use(&self) -> usize {
     387           14 :         let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
     388           14 :         map.buckets_in_use as usize
     389           14 :     }
     390              : 
     391              :     /// Clears all entries in a table. Does not reset any shrinking operations.
     392            2 :     pub fn clear(&self) {
     393            2 :         let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
     394            2 :         map.clear();
     395            2 :     }
     396              : 
     397              :     /// Perform an in-place rehash of some region (0..`rehash_buckets`) of the table and reset
     398              :     /// the `buckets` and `dictionary` slices to be as long as `num_buckets`. Resets the freelist
     399              :     /// in the process.
     400            9 :     fn rehash_dict(
     401            9 :         &self,
     402            9 :         inner: &mut CoreHashMap<'a, K, V>,
     403            9 :         buckets_ptr: *mut core::Bucket<K, V>,
     404            9 :         end_ptr: *mut u8,
     405            9 :         num_buckets: u32,
     406            9 :         rehash_buckets: u32,
     407            9 :     ) {
     408            9 :         inner.free_head = INVALID_POS;
     409              : 
     410              :         let buckets;
     411              :         let dictionary;
     412            9 :         unsafe {
     413            9 :             let buckets_end_ptr = buckets_ptr.add(num_buckets as usize);
     414            9 :             let dictionary_ptr: *mut u32 = buckets_end_ptr
     415            9 :                 .byte_add(buckets_end_ptr.align_offset(align_of::<u32>()))
     416            9 :                 .cast();
     417            9 :             let dictionary_size: usize =
     418            9 :                 end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::<u32>();
     419            9 : 
     420            9 :             buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize);
     421            9 :             dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size);
     422            9 :         }
     423        28802 :         for e in dictionary.iter_mut() {
     424        28802 :             *e = INVALID_POS;
     425        28802 :         }
     426              : 
     427         4900 :         for (i, bucket) in buckets.iter_mut().enumerate().take(rehash_buckets as usize) {
     428         4900 :             if bucket.inner.is_none() {
     429         2157 :                 bucket.next = inner.free_head;
     430         2157 :                 inner.free_head = i as u32;
     431         2157 :                 continue;
     432         2743 :             }
     433              : 
     434         2743 :             let hash = self.hasher.hash_one(&bucket.inner.as_ref().unwrap().0);
     435         2743 :             let pos: usize = (hash % dictionary.len() as u64) as usize;
     436         2743 :             bucket.next = dictionary[pos];
     437         2743 :             dictionary[pos] = i as u32;
     438              :         }
     439              : 
     440            9 :         inner.dictionary = dictionary;
     441            9 :         inner.buckets = buckets;
     442            9 :     }
     443              : 
     444              :     /// Rehash the map without growing or shrinking.
     445            1 :     pub fn shuffle(&self) {
     446            1 :         let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
     447            1 :         let num_buckets = map.get_num_buckets() as u32;
     448            1 :         let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
     449            1 :         let end_ptr: *mut u8 = unsafe { self.shared_ptr.byte_add(size_bytes).cast() };
     450            1 :         let buckets_ptr = map.buckets.as_mut_ptr();
     451            1 :         self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
     452            1 :     }
     453              : 
     454              :     /// Grow the number of buckets within the table.
     455              :     ///
     456              :     /// 1. Grows the underlying shared memory area
     457              :     /// 2. Initializes new buckets and overwrites the current dictionary
     458              :     /// 3. Rehashes the dictionary
     459              :     ///
     460              :     /// # Panics
     461              :     /// Panics if called on a map initialized with [`HashMapInit::with_fixed`].
     462              :     ///
     463              :     /// # Errors
     464              :     /// Returns an [`shmem::Error`] if any errors occur resizing the memory region.
     465            5 :     pub fn grow(&self, num_buckets: u32) -> Result<(), shmem::Error> {
     466            5 :         let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
     467            5 :         let old_num_buckets = map.buckets.len() as u32;
     468              : 
     469            5 :         assert!(
     470            5 :             num_buckets >= old_num_buckets,
     471            0 :             "grow called with a smaller number of buckets"
     472              :         );
     473            5 :         if num_buckets == old_num_buckets {
     474            0 :             return Ok(());
     475            5 :         }
     476            5 :         let shmem_handle = self
     477            5 :             .shmem_handle
     478            5 :             .as_ref()
     479            5 :             .expect("grow called on a fixed-size hash table");
     480              : 
     481            5 :         let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
     482            5 :         shmem_handle.set_size(size_bytes)?;
     483            5 :         let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
     484              : 
     485              :         // Initialize new buckets. The new buckets are linked to the free list.
     486              :         // NB: This overwrites the dictionary!
     487            5 :         let buckets_ptr = map.buckets.as_mut_ptr();
     488              :         unsafe {
     489        11099 :             for i in old_num_buckets..num_buckets {
     490        11099 :                 let bucket = buckets_ptr.add(i as usize);
     491        11099 :                 bucket.write(core::Bucket {
     492        11099 :                     next: if i < num_buckets - 1 {
     493        11096 :                         i + 1
     494              :                     } else {
     495            3 :                         map.free_head
     496              :                     },
     497        11099 :                     inner: None,
     498              :                 });
     499              :             }
     500              :         }
     501              : 
     502            5 :         self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, old_num_buckets);
     503            5 :         map.free_head = old_num_buckets;
     504              : 
     505            5 :         Ok(())
     506            5 :     }
     507              : 
     508              :     /// Begin a shrink, limiting all new allocations to be in buckets with index below `num_buckets`.
     509              :     ///
     510              :     /// # Panics
     511              :     /// Panics if called on a map initialized with [`HashMapInit::with_fixed`] or if `num_buckets` is
     512              :     /// greater than the number of buckets in the map.
     513            6 :     pub fn begin_shrink(&mut self, num_buckets: u32) {
     514            6 :         let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
     515            6 :         assert!(
     516            6 :             num_buckets <= map.get_num_buckets() as u32,
     517            1 :             "shrink called with a larger number of buckets"
     518              :         );
     519            5 :         _ = self
     520            5 :             .shmem_handle
     521            5 :             .as_ref()
     522            5 :             .expect("shrink called on a fixed-size hash table");
     523            5 :         map.alloc_limit = num_buckets;
     524            5 :     }
     525              : 
     526              :     /// If a shrink operation is underway, returns the target size of the map. Otherwise, returns None.
     527            9 :     pub fn shrink_goal(&self) -> Option<usize> {
     528            9 :         let map = unsafe { self.shared_ptr.as_mut() }.unwrap().read();
     529            9 :         let goal = map.alloc_limit;
     530            9 :         if goal == INVALID_POS {
     531            6 :             None
     532              :         } else {
     533            3 :             Some(goal as usize)
     534              :         }
     535            9 :     }
     536              : 
     537              :     /// Complete a shrink after caller has evicted entries, removing the unused buckets and rehashing.
     538              :     ///
     539              :     /// # Panics
     540              :     /// The following cases result in a panic:
     541              :     /// - Calling this function on a map initialized with [`HashMapInit::with_fixed`].
     542              :     /// - Calling this function on a map when no shrink operation is in progress.
     543            5 :     pub fn finish_shrink(&self) -> Result<(), HashMapShrinkError> {
     544            5 :         let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
     545            5 :         assert!(
     546            5 :             map.alloc_limit != INVALID_POS,
     547            1 :             "called finish_shrink when no shrink is in progress"
     548              :         );
     549              : 
     550            4 :         let num_buckets = map.alloc_limit;
     551              : 
     552            4 :         if map.get_num_buckets() == num_buckets as usize {
     553            0 :             return Ok(());
     554            4 :         }
     555              : 
     556            4 :         assert!(
     557            4 :             map.buckets_in_use <= num_buckets,
     558            0 :             "called finish_shrink before enough entries were removed"
     559              :         );
     560              : 
     561         3550 :         for i in (num_buckets as usize)..map.buckets.len() {
     562         3550 :             if map.buckets[i].inner.is_some() {
     563            0 :                 return Err(HashMapShrinkError::RemainingEntries(i));
     564         3550 :             }
     565              :         }
     566              : 
     567            4 :         let shmem_handle = self
     568            4 :             .shmem_handle
     569            4 :             .as_ref()
     570            4 :             .expect("shrink called on a fixed-size hash table");
     571              : 
     572            4 :         let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
     573            4 :         if let Err(e) = shmem_handle.set_size(size_bytes) {
     574            0 :             return Err(HashMapShrinkError::ResizeError(e));
     575            4 :         }
     576            4 :         let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
     577            4 :         let buckets_ptr = map.buckets.as_mut_ptr();
     578            4 :         self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
     579            4 :         map.alloc_limit = INVALID_POS;
     580              : 
     581            4 :         Ok(())
     582            4 :     }
     583              : }
        

Generated by: LCOV version 2.1-beta