LCOV - code coverage report
Current view: top level - libs/neon-shmem/src - shmem.rs (source / functions)
Test:      1e20c4f2b28aa592527961bb32170ebbd2c9172f.info
Test Date: 2025-07-16 12:29:03
Coverage:  Lines: 91.9 % (159 of 173)    Functions: 73.7 % (14 of 19)

            Line data    Source code
       1              : //! Dynamically resizable contiguous chunk of shared memory
       2              : 
       3              : use std::num::NonZeroUsize;
       4              : use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
       5              : use std::ptr::NonNull;
       6              : use std::sync::atomic::{AtomicUsize, Ordering};
       7              : 
       8              : use nix::errno::Errno;
       9              : use nix::sys::mman::MapFlags;
      10              : use nix::sys::mman::ProtFlags;
      11              : use nix::sys::mman::mmap as nix_mmap;
      12              : use nix::sys::mman::munmap as nix_munmap;
      13              : use nix::unistd::ftruncate as nix_ftruncate;
      14              : 
      15              : /// `ShmemHandle` represents a shared memory area that can be shared by processes over `fork()`.
       16              : /// Unlike shared memory allocated by Postgres, this area is resizable, up to the `max_size`
       17              : /// specified at creation.
      18              : ///
      19              : /// The area is backed by an anonymous file created with `memfd_create()`. The full address space for
      20              : /// `max_size` is reserved up-front with `mmap()`, but whenever you call [`ShmemHandle::set_size`],
       21              : /// the underlying file is resized. Do not access the area beyond the current size. That is not
       22              : /// enforced at the moment, but we might use `mprotect()` etc. to enforce it in the
       23              : /// future.
      24              : pub struct ShmemHandle {
      25              :     /// memfd file descriptor
      26              :     fd: OwnedFd,
      27              : 
      28              :     max_size: usize,
      29              : 
      30              :     // Pointer to the beginning of the shared memory area. The header is stored there.
      31              :     shared_ptr: NonNull<SharedStruct>,
      32              : 
      33              :     // Pointer to the beginning of the user data
      34              :     pub data_ptr: NonNull<u8>,
      35              : }
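
As a usage illustration of the create/resize workflow described in the doc comment above, here is a minimal sketch. It assumes the module is exported as `neon_shmem::shmem`; the `demo` function is made up for the example, and a real caller would `fork()` after creation to share the area, as the tests at the bottom of this file do.

    // Sketch only: assumes the crate exposes this module as `neon_shmem::shmem`.
    use neon_shmem::shmem::{Error, ShmemHandle};

    fn demo() -> Result<(), Error> {
        // Reserve 1 MiB of address space up front, but start with an empty area.
        let shmem = ShmemHandle::new("demo_area", 0, 1024 * 1024)?;

        // Grow the backing file before touching the memory.
        shmem.set_size(4096)?;
        assert_eq!(shmem.current_size(), 4096);

        // Writes go through `data_ptr`; processes fork()'d after `new()` see the
        // same bytes through their inherited mapping.
        unsafe { std::ptr::write_bytes(shmem.data_ptr.as_ptr(), 0xAB, shmem.current_size()) };
        Ok(())
    }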
      36              : 
       37              : /// This is stored at the beginning of the shared memory area.
      38              : struct SharedStruct {
      39              :     max_size: usize,
      40              : 
      41              :     /// Current size of the backing file. The high-order bit is used for the [`RESIZE_IN_PROGRESS`] flag.
      42              :     current_size: AtomicUsize,
      43              : }
      44              : 
      45              : const RESIZE_IN_PROGRESS: usize = 1 << 63;
      46              : 
      47              : const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
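
To make the layout of `current_size` concrete: the low bits hold the size in bytes and the top bit is the resize lock. The helpers below are illustrative only (they are not part of this module); they mirror what `set_size()` and `current_size()` do further down, using the `RESIZE_IN_PROGRESS` constant defined above.

    fn pack(size: usize, resizing: bool) -> usize {
        // Cannot collide with the flag bit: `new_with_fd` asserts max_size < 1 << 48.
        size | if resizing { RESIZE_IN_PROGRESS } else { 0 }
    }

    fn unpack(word: usize) -> (usize, bool) {
        // Mask the flag out to recover the size, and report whether a resize is in flight.
        (word & !RESIZE_IN_PROGRESS, (word & RESIZE_IN_PROGRESS) != 0)
    }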
      48              : 
      49              : /// Error type returned by the [`ShmemHandle`] functions.
      50              : #[derive(thiserror::Error, Debug)]
      51              : #[error("{msg}: {errno}")]
      52              : pub struct Error {
      53              :     pub msg: String,
      54              :     pub errno: Errno,
      55              : }
      56              : 
      57              : impl Error {
      58            0 :     fn new(msg: &str, errno: Errno) -> Self {
      59            0 :         Self {
      60            0 :             msg: msg.to_string(),
      61            0 :             errno,
      62            0 :         }
      63            0 :     }
      64              : }
      65              : 
      66              : impl ShmemHandle {
      67              :     /// Create a new shared memory area. To communicate between processes, the processes need to be
      68              :     /// `fork()`'d after calling this, so that the `ShmemHandle` is inherited by all processes.
      69              :     ///
      70              :     /// If the `ShmemHandle` is dropped, the memory is unmapped from the current process. Other
      71              :     /// processes can continue using it, however.
      72            3 :     pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<Self, Error> {
      73              :         // create the backing anonymous file.
      74            3 :         let fd = create_backing_file(name)?;
      75              : 
      76            3 :         Self::new_with_fd(fd, initial_size, max_size)
      77            3 :     }
      78              : 
      79            3 :     fn new_with_fd(fd: OwnedFd, initial_size: usize, max_size: usize) -> Result<Self, Error> {
       80              : // We reserve the high-order bit of `current_size` for the `RESIZE_IN_PROGRESS` flag, and the
       81              : // actual file size is a little larger than the caller's size because of the SharedStruct
       82              : // header. Keep the upper limit well below the flag bit anyway; with anything close to it,
       83              : // you'd run out of memory long before.
      84            3 :         assert!(max_size < 1 << 48, "max size {max_size} too large");
      85              : 
      86            3 :         assert!(
      87            3 :             initial_size <= max_size,
      88            0 :             "initial size {initial_size} larger than max size {max_size}"
      89              :         );
      90              : 
      91              :         // The actual initial / max size is the one given by the caller, plus the size of
      92              :         // 'SharedStruct'.
      93            3 :         let initial_size = HEADER_SIZE + initial_size;
      94            3 :         let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
      95              : 
      96              :         // Reserve address space for it with mmap
      97              :         //
      98              :         // TODO: Use MAP_HUGETLB if possible
      99            3 :         let start_ptr = unsafe {
     100            3 :             nix_mmap(
     101            3 :                 None,
     102            3 :                 max_size,
     103            3 :                 ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
     104              :                 MapFlags::MAP_SHARED,
     105            3 :                 &fd,
     106              :                 0,
     107              :             )
     108              :         }
     109            3 :         .map_err(|e| Error::new("mmap failed", e))?;
     110              : 
     111              :         // Reserve space for the initial size
     112            3 :         enlarge_file(fd.as_fd(), initial_size as u64)?;
     113              : 
     114              :         // Initialize the header
     115            3 :         let shared: NonNull<SharedStruct> = start_ptr.cast();
     116            3 :         unsafe {
     117            3 :             shared.write(SharedStruct {
     118            3 :                 max_size: max_size.into(),
     119            3 :                 current_size: AtomicUsize::new(initial_size),
     120            3 :             });
     121            3 :         }
     122              : 
     123              :         // The user data begins after the header
     124            3 :         let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
     125              : 
     126            3 :         Ok(Self {
     127            3 :             fd,
     128            3 :             max_size: max_size.into(),
     129            3 :             shared_ptr: shared,
     130            3 :             data_ptr,
     131            3 :         })
     132            3 :     }
     133              : 
     134              :     // return reference to the header
     135           10 :     fn shared(&self) -> &SharedStruct {
     136           10 :         unsafe { self.shared_ptr.as_ref() }
     137           10 :     }
     138              : 
     139              :     /// Resize the shared memory area. `new_size` must not be larger than the `max_size` specified
     140              :     /// when creating the area.
     141              :     ///
      142              :     /// This may only be called from one process/thread at a time. Concurrent calls are
      143              :     /// detected and return an [`shmem::Error`](Error).
     144            6 :     pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
     145            6 :         let new_size = new_size + HEADER_SIZE;
     146            6 :         let shared = self.shared();
     147              : 
     148            6 :         assert!(
     149            6 :             new_size <= self.max_size,
     150            0 :             "new size ({new_size}) is greater than max size ({})",
     151              :             self.max_size
     152              :         );
     153              : 
     154            6 :         assert_eq!(self.max_size, shared.max_size);
     155              : 
     156              :         // Lock the area by setting the bit in `current_size`
     157              :         //
     158              :         // Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
     159              :         // and the `posix_fallocate`/`ftruncate` call is surely a synchronization point anyway. But
     160              :         // since this is not performance-critical, better safe than sorry.
     161            6 :         let mut old_size = shared.current_size.load(Ordering::Acquire);
     162              :         loop {
     163            6 :             if (old_size & RESIZE_IN_PROGRESS) != 0 {
     164            0 :                 return Err(Error::new(
     165            0 :                     "concurrent resize detected",
     166            0 :                     Errno::UnknownErrno,
     167            0 :                 ));
     168            6 :             }
     169            6 :             match shared.current_size.compare_exchange(
     170            6 :                 old_size,
     171            6 :                 new_size,
     172            6 :                 Ordering::Acquire,
     173            6 :                 Ordering::Relaxed,
     174            6 :             ) {
     175            6 :                 Ok(_) => break,
     176            0 :                 Err(x) => old_size = x,
     177              :             }
     178              :         }
     179              : 
     180              :         // Ok, we got the lock.
     181              :         //
     182              :         // NB: If anything goes wrong, we *must* clear the bit!
     183            6 :         let result = {
     184              :             use std::cmp::Ordering::{Equal, Greater, Less};
     185            6 :             match new_size.cmp(&old_size) {
     186            1 :                 Less => nix_ftruncate(&self.fd, new_size as i64)
     187            1 :                     .map_err(|e| Error::new("could not shrink shmem segment, ftruncate failed", e)),
     188            0 :                 Equal => Ok(()),
     189            5 :                 Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
     190              :             }
     191              :         };
     192              : 
     193              :         // Unlock
     194            6 :         shared.current_size.store(
     195            6 :             if result.is_ok() { new_size } else { old_size },
     196            6 :             Ordering::Release,
     197              :         );
     198              : 
     199            6 :         result
     200            6 :     }
     201              : 
     202              :     /// Returns the current user-visible size of the shared memory segment.
     203              :     ///
     204              :     /// NOTE: a concurrent [`ShmemHandle::set_size()`] call can change the size at any time.
     205              :     /// It is the caller's responsibility not to access the area beyond the current size.
     206            4 :     pub fn current_size(&self) -> usize {
     207            4 :         let total_current_size =
     208            4 :             self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
     209            4 :         total_current_size - HEADER_SIZE
     210            4 :     }
     211              : }
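
Since a concurrent `set_size()` can shrink the area at any moment, callers have to bound their own accesses by `current_size()`. A sketch of such a bounds-checked read follows; the `read_at` helper is hypothetical, and the check is only best-effort because the size may change again right after it is read.

    fn read_at(shmem: &ShmemHandle, offset: usize) -> Option<u8> {
        // Snapshot the current size and refuse to read past it.
        if offset < shmem.current_size() {
            // SAFETY: `offset` was within the area when we checked; a concurrent
            // shrink can still invalidate this, which is the caller's responsibility.
            Some(unsafe { *shmem.data_ptr.as_ptr().add(offset) })
        } else {
            None
        }
    }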
     212              : 
     213              : impl Drop for ShmemHandle {
     214            3 :     fn drop(&mut self) {
     215              :         // SAFETY: The pointer was obtained from mmap() with the given size.
     216              :         // We unmap the entire region.
     217            3 :         let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
     218              :         // The fd is dropped automatically by OwnedFd.
     219            3 :     }
     220              : }
     221              : 
      222              : /// Create a "backing file" for the shared memory area. On Linux, use `memfd_create()` to create an
      223              : /// anonymous in-memory file. On macOS, fall back to a regular file. That's good enough for
     224              : /// development and testing, but in production we want the file to stay in memory.
     225              : ///
      226              : /// Disable the unused-variable warning because `name` is unused in the macOS path.
     227              : #[allow(unused_variables)]
     228            3 : fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
     229              :     #[cfg(not(target_os = "macos"))]
     230              :     {
     231            3 :         nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
     232            3 :             .map_err(|e| Error::new("memfd_create failed", e))
     233              :     }
     234              :     #[cfg(target_os = "macos")]
     235              :     {
     236              :         let file = tempfile::tempfile().map_err(|e| {
     237              :             Error::new(
     238              :                 "could not create temporary file to back shmem area",
     239              :                 nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
     240              :             )
     241              :         })?;
     242              :         Ok(OwnedFd::from(file))
     243              :     }
     244            3 : }
     245              : 
     246            8 : fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
     247              :     // Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
     248              :     // we don't get a segfault later when trying to actually use it.
     249              :     #[cfg(not(target_os = "macos"))]
     250              :     {
     251            8 :         nix::fcntl::posix_fallocate(fd, 0, size as i64)
     252            8 :             .map_err(|e| Error::new("could not grow shmem segment, posix_fallocate failed", e))
     253              :     }
      254              : // As a fallback on macOS, which doesn't have posix_fallocate, use plain 'ftruncate'
     255              :     #[cfg(target_os = "macos")]
     256              :     {
     257              :         nix::unistd::ftruncate(fd, size as i64)
     258              :             .map_err(|e| Error::new("could not grow shmem segment, ftruncate failed", e))
     259              :     }
     260            8 : }
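
Because `posix_fallocate()` reserves the blocks up front, a grow that cannot be satisfied typically fails with `ENOSPC` from `set_size()` instead of faulting later, and `set_size()` restores the previous size before returning the error. A sketch of how a caller might react; the `try_grow` helper is illustrative only.

    use nix::errno::Errno;

    fn try_grow(shmem: &ShmemHandle, wanted: usize) -> bool {
        match shmem.set_size(wanted) {
            Ok(()) => true,
            // The previous size is still in effect, so the old range stays usable.
            Err(e) if e.errno == Errno::ENOSPC => false,
            Err(e) => panic!("unexpected shmem error: {e}"),
        }
    }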
     261              : 
     262              : #[cfg(test)]
     263              : mod tests {
     264              :     use super::*;
     265              : 
     266              :     use nix::unistd::ForkResult;
     267              :     use std::ops::Range;
     268              : 
      269              :     /// Check that all bytes in the given range have the expected value.
     270           13 :     fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
     271        26013 :         for i in range {
     272        26000 :             let b = unsafe { *(ptr.add(i)) };
     273        26000 :             assert_eq!(expected, b, "unexpected byte at offset {i}");
     274              :         }
     275           13 :     }
     276              : 
     277              :     /// Write 'b' to all bytes in the given range
     278            5 :     fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
     279            5 :         unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
     280            5 :     }
     281              : 
     282              :     // simple single-process test of growing and shrinking
     283              :     #[test]
     284            1 :     fn test_shmem_resize() -> Result<(), Error> {
     285            1 :         let max_size = 1024 * 1024;
     286            1 :         let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
     287              : 
     288            1 :         assert_eq!(init_struct.current_size(), 0);
     289              : 
     290              :         // Initial grow
     291            1 :         let size1 = 10000;
     292            1 :         init_struct.set_size(size1).unwrap();
     293            1 :         assert_eq!(init_struct.current_size(), size1);
     294              : 
     295              :         // Write some data
     296            1 :         let data_ptr = init_struct.data_ptr.as_ptr();
     297            1 :         write_range(data_ptr, 0xAA, 0..size1);
     298            1 :         assert_range(data_ptr, 0xAA, 0..size1);
     299              : 
     300              :         // Shrink
     301            1 :         let size2 = 5000;
     302            1 :         init_struct.set_size(size2).unwrap();
     303            1 :         assert_eq!(init_struct.current_size(), size2);
     304              : 
     305              :         // Grow again
     306            1 :         let size3 = 20000;
     307            1 :         init_struct.set_size(size3).unwrap();
     308            1 :         assert_eq!(init_struct.current_size(), size3);
     309              : 
     310              :         // Try to read it. The area that was shrunk and grown again should read as all zeros now
     311            1 :         assert_range(data_ptr, 0xAA, 0..5000);
     312            1 :         assert_range(data_ptr, 0, 5000..size1);
     313              : 
     314              :         // Try to grow beyond max_size
     315              :         //let size4 = max_size + 1;
     316              :         //assert!(init_struct.set_size(size4).is_err());
     317              : 
     318              :         // Dropping init_struct should unmap the memory
     319            1 :         drop(init_struct);
     320              : 
     321            1 :         Ok(())
     322            1 :     }
     323              : 
     324              :     /// This is used in tests to coordinate between test processes. It's like `std::sync::Barrier`,
     325              :     /// but is stored in the shared memory area and works across processes. It's implemented by
      326              :     /// polling, because e.g. standard Rust mutexes are not guaranteed to work across processes.
     327              :     struct SimpleBarrier {
     328              :         num_procs: usize,
     329              :         count: AtomicUsize,
     330              :     }
     331              : 
     332              :     impl SimpleBarrier {
     333            2 :         unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
     334              :             unsafe {
     335            2 :                 *ptr = SimpleBarrier {
     336            2 :                     num_procs,
     337            2 :                     count: AtomicUsize::new(0),
     338            2 :                 }
     339              :             }
     340            2 :         }
     341              : 
     342            6 :         pub fn wait(&self) {
     343            6 :             let old = self.count.fetch_add(1, Ordering::Relaxed);
     344              : 
     345            6 :             let generation = old / self.num_procs;
     346              : 
     347            6 :             let mut current = old + 1;
     348            9 :             while current < (generation + 1) * self.num_procs {
     349            3 :                 std::thread::sleep(std::time::Duration::from_millis(10));
     350            3 :                 current = self.count.load(Ordering::Relaxed);
     351            3 :             }
     352            6 :         }
     353              :     }
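
A quick worked example of the generation arithmetic in `wait()`: each caller computes which generation it belongs to from the counter value it saw before incrementing, and spins until the counter reaches the end of that generation. The `barrier_target` helper below is illustrative only.

    fn barrier_target(old: usize, num_procs: usize) -> usize {
        // Same arithmetic as `wait()`: generation = old / num_procs, and the
        // barrier opens once the counter reaches (generation + 1) * num_procs.
        (old / num_procs + 1) * num_procs
    }

    #[test]
    fn barrier_target_examples() {
        // With two processes: the first barrier's callers (old = 0 and 1) wait for
        // the counter to reach 2, the second barrier's callers (old = 2 and 3) wait
        // for 4, and so on, so the same counter serves every barrier point.
        assert_eq!(barrier_target(0, 2), 2);
        assert_eq!(barrier_target(1, 2), 2);
        assert_eq!(barrier_target(3, 2), 4);
    }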
     354              : 
     355              :     #[test]
     356            2 :     fn test_multi_process() {
     357              :         // Initialize
     358            2 :         let max_size = 1_000_000_000_000;
     359            2 :         let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
     360            2 :         let ptr = init_struct.data_ptr.as_ptr();
     361              : 
     362              :         // Store the SimpleBarrier in the first 1k of the area.
     363            2 :         init_struct.set_size(10000).unwrap();
     364            2 :         let barrier_ptr: *mut SimpleBarrier = unsafe {
     365            2 :             ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
     366            2 :                 .cast()
     367              :         };
     368            2 :         unsafe { SimpleBarrier::init(barrier_ptr, 2) };
     369            2 :         let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
     370              : 
     371              :         // Fork another test process. The code after this runs in both processes concurrently.
     372            2 :         let fork_result = unsafe { nix::unistd::fork().unwrap() };
     373              : 
     374              :         // In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
     375            2 :         if fork_result.is_parent() {
     376            1 :             write_range(ptr, 0xAA, 1000..2000);
     377            1 :         } else {
     378            1 :             write_range(ptr, 0xBB, 2000..3000);
     379            1 :         }
     380            2 :         barrier.wait();
     381              :         // Verify the contents. (in both processes)
     382            2 :         assert_range(ptr, 0xAA, 1000..2000);
     383            2 :         assert_range(ptr, 0xBB, 2000..3000);
     384              : 
     385              :         // Grow, from the child this time
     386            2 :         let size = 10_000_000;
     387            2 :         if !fork_result.is_parent() {
     388            1 :             init_struct.set_size(size).unwrap();
     389            1 :         }
     390            2 :         barrier.wait();
     391              : 
     392              :         // make some writes at the end
     393            2 :         if fork_result.is_parent() {
     394            1 :             write_range(ptr, 0xAA, (size - 10)..size);
     395            1 :         } else {
     396            1 :             write_range(ptr, 0xBB, (size - 20)..(size - 10));
     397            1 :         }
     398            2 :         barrier.wait();
     399              : 
     400              :         // Verify the contents. (This runs in both processes)
     401            2 :         assert_range(ptr, 0, (size - 1000)..(size - 20));
     402            2 :         assert_range(ptr, 0xBB, (size - 20)..(size - 10));
     403            2 :         assert_range(ptr, 0xAA, (size - 10)..size);
     404              : 
     405            2 :         if let ForkResult::Parent { child } = fork_result {
     406            1 :             nix::sys::wait::waitpid(child, None).unwrap();
     407            1 :         }
     408            2 :     }
     409              : }
        

Generated by: LCOV version 2.1-beta