LCOV - code coverage report
Current view: top level - libs/neon-shmem/src - lib.rs (source / functions) Coverage Total Hit
Test: 15f04989d2faf4ce76cecb56042184aca56ebae6.info Lines: 87.9 % 182 160
Test Date: 2025-07-14 11:50:36 Functions: 73.7 % 19 14

            Line data    Source code
       1              : //! Shared memory utilities for neon communicator
       2              : 
       3              : use std::num::NonZeroUsize;
       4              : use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
       5              : use std::ptr::NonNull;
       6              : use std::sync::atomic::{AtomicUsize, Ordering};
       7              : 
       8              : use nix::errno::Errno;
       9              : use nix::sys::mman::MapFlags;
      10              : use nix::sys::mman::ProtFlags;
      11              : use nix::sys::mman::mmap as nix_mmap;
      12              : use nix::sys::mman::munmap as nix_munmap;
      13              : use nix::unistd::ftruncate as nix_ftruncate;
      14              : 
      15              : /// ShmemHandle represents a shared memory area that can be shared by processes over fork().
      16              : /// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
      17              : /// specified at creation.
      18              : ///
      19              : /// The area is backed by an anonymous file created with memfd_create(). The full address space for
      20              : /// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
      21              : /// the underlying file is resized. Do not access the area beyond the current size. Currently, that
      22              : /// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
      23              : /// future.
      24              : pub struct ShmemHandle {
      25              :     /// memfd file descriptor
      26              :     fd: OwnedFd,
      27              : 
      28              :     max_size: usize,
      29              : 
      30              :     // Pointer to the beginning of the shared memory area. The header is stored there.
      31              :     shared_ptr: NonNull<SharedStruct>,
      32              : 
      33              :     // Pointer to the beginning of the user data
      34              :     pub data_ptr: NonNull<u8>,
      35              : }
      36              : 
      37              : /// This is stored at the beginning in the shared memory area.
      38              : struct SharedStruct {
      39              :     max_size: usize,
      40              : 
      41              :     /// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag
      42              :     current_size: AtomicUsize,
      43              : }
      44              : 
      45              : const RESIZE_IN_PROGRESS: usize = 1 << 63;
      46              : 
      47              : const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
      48              : 
      49              : /// Error type returned by the ShmemHandle functions.
      50              : #[derive(thiserror::Error, Debug)]
      51              : #[error("{msg}: {errno}")]
      52              : pub struct Error {
      53              :     pub msg: String,
      54              :     pub errno: Errno,
      55              : }
      56              : 
      57              : impl Error {
      58            0 :     fn new(msg: &str, errno: Errno) -> Error {
      59            0 :         Error {
      60            0 :             msg: msg.to_string(),
      61            0 :             errno,
      62            0 :         }
      63            0 :     }
      64              : }
      65              : 
      66              : impl ShmemHandle {
      67              :     /// Create a new shared memory area. To communicate between processes, the processes need to be
      68              :     /// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
      69              :     ///
      70              :     /// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
      71              :     /// processes can continue using it, however.
      72            3 :     pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
      73              :         // create the backing anonymous file.
      74            3 :         let fd = create_backing_file(name)?;
      75              : 
      76            3 :         Self::new_with_fd(fd, initial_size, max_size)
      77            3 :     }
      78              : 
      79            3 :     fn new_with_fd(
      80            3 :         fd: OwnedFd,
      81            3 :         initial_size: usize,
      82            3 :         max_size: usize,
      83            3 :     ) -> Result<ShmemHandle, Error> {
      84              :         // We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
      85              :         // is a little larger than this because of the SharedStruct header. Make the upper limit
      86              :         // somewhat smaller than that, because with anything close to that, you'll run out of
      87              :         // memory anyway.
      88            3 :         if max_size >= 1 << 48 {
      89            0 :             panic!("max size {max_size} too large");
      90            3 :         }
      91            3 :         if initial_size > max_size {
      92            0 :             panic!("initial size {initial_size} larger than max size {max_size}");
      93            3 :         }
      94              : 
      95              :         // The actual initial / max size is the one given by the caller, plus the size of
      96              :         // 'SharedStruct'.
      97            3 :         let initial_size = HEADER_SIZE + initial_size;
      98            3 :         let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
      99              : 
     100              :         // Reserve address space for it with mmap
     101              :         //
     102              :         // TODO: Use MAP_HUGETLB if possible
     103            3 :         let start_ptr = unsafe {
     104            3 :             nix_mmap(
     105            3 :                 None,
     106            3 :                 max_size,
     107            3 :                 ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
     108              :                 MapFlags::MAP_SHARED,
     109            3 :                 &fd,
     110              :                 0,
     111              :             )
     112              :         }
     113            3 :         .map_err(|e| Error::new("mmap failed: {e}", e))?;
     114              : 
     115              :         // Reserve space for the initial size
     116            3 :         enlarge_file(fd.as_fd(), initial_size as u64)?;
     117              : 
     118              :         // Initialize the header
     119            3 :         let shared: NonNull<SharedStruct> = start_ptr.cast();
     120              :         unsafe {
     121            3 :             shared.write(SharedStruct {
     122            3 :                 max_size: max_size.into(),
     123            3 :                 current_size: AtomicUsize::new(initial_size),
     124            3 :             })
     125              :         };
     126              : 
     127              :         // The user data begins after the header
     128            3 :         let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
     129              : 
     130            3 :         Ok(ShmemHandle {
     131            3 :             fd,
     132            3 :             max_size: max_size.into(),
     133            3 :             shared_ptr: shared,
     134            3 :             data_ptr,
     135            3 :         })
     136            3 :     }
     137              : 
     138              :     // return reference to the header
     139           10 :     fn shared(&self) -> &SharedStruct {
     140           10 :         unsafe { self.shared_ptr.as_ref() }
     141           10 :     }
     142              : 
     143              :     /// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
     144              :     /// when creating the area.
     145              :     ///
     146              :     /// This may only be called from one process/thread concurrently. We detect that case
     147              :     /// and return an Error.
     148            6 :     pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
     149            6 :         let new_size = new_size + HEADER_SIZE;
     150            6 :         let shared = self.shared();
     151              : 
     152            6 :         if new_size > self.max_size {
     153            0 :             panic!(
     154            0 :                 "new size ({} is greater than max size ({})",
     155              :                 new_size, self.max_size
     156              :             );
     157            6 :         }
     158            6 :         assert_eq!(self.max_size, shared.max_size);
     159              : 
     160              :         // Lock the area by setting the bit in 'current_size'
     161              :         //
     162              :         // Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
     163              :         // and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
     164              :         // since this is not performance-critical, better safe than sorry.
     165            6 :         let mut old_size = shared.current_size.load(Ordering::Acquire);
     166              :         loop {
     167            6 :             if (old_size & RESIZE_IN_PROGRESS) != 0 {
     168            0 :                 return Err(Error::new(
     169            0 :                     "concurrent resize detected",
     170            0 :                     Errno::UnknownErrno,
     171            0 :                 ));
     172            6 :             }
     173            6 :             match shared.current_size.compare_exchange(
     174            6 :                 old_size,
     175            6 :                 new_size,
     176            6 :                 Ordering::Acquire,
     177            6 :                 Ordering::Relaxed,
     178            6 :             ) {
     179            6 :                 Ok(_) => break,
     180            0 :                 Err(x) => old_size = x,
     181              :             }
     182              :         }
     183              : 
     184              :         // Ok, we got the lock.
     185              :         //
     186              :         // NB: If anything goes wrong, we *must* clear the bit!
     187            6 :         let result = {
     188              :             use std::cmp::Ordering::{Equal, Greater, Less};
     189            6 :             match new_size.cmp(&old_size) {
     190            1 :                 Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
     191            0 :                     Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
     192            0 :                 }),
     193            0 :                 Equal => Ok(()),
     194            5 :                 Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
     195              :             }
     196              :         };
     197              : 
     198              :         // Unlock
     199            6 :         shared.current_size.store(
     200            6 :             if result.is_ok() { new_size } else { old_size },
     201            6 :             Ordering::Release,
     202              :         );
     203              : 
     204            6 :         result
     205            6 :     }
     206              : 
     207              :     /// Returns the current user-visible size of the shared memory segment.
     208              :     ///
     209              :     /// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
     210              :     /// responsibility not to access the area beyond the current size.
     211            4 :     pub fn current_size(&self) -> usize {
     212            4 :         let total_current_size =
     213            4 :             self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
     214            4 :         total_current_size - HEADER_SIZE
     215            4 :     }
     216              : }
     217              : 
     218              : impl Drop for ShmemHandle {
     219            3 :     fn drop(&mut self) {
     220              :         // SAFETY: The pointer was obtained from mmap() with the given size.
     221              :         // We unmap the entire region.
     222            3 :         let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
     223              :         // The fd is dropped automatically by OwnedFd.
     224            3 :     }
     225              : }
     226              : 
     227              : /// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an
     228              : /// anonymous in-memory file. On macOS, fall back to a regular file. That's good enough for
     229              : /// development and testing, but in production we want the file to stay in memory.
     230              : ///
     231              : /// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
     232              : #[allow(unused_variables)]
     233            3 : fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
     234              :     #[cfg(not(target_os = "macos"))]
     235              :     {
     236            3 :         nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
     237            3 :             .map_err(|e| Error::new("memfd_create failed: {e}", e))
     238              :     }
     239              :     #[cfg(target_os = "macos")]
     240              :     {
     241              :         let file = tempfile::tempfile().map_err(|e| {
     242              :             Error::new(
     243              :                 "could not create temporary file to back shmem area: {e}",
     244              :                 nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
     245              :             )
     246              :         })?;
     247              :         Ok(OwnedFd::from(file))
     248              :     }
     249            3 : }
     250              : 
     251            8 : fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
     252              :     // Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
     253              :     // we don't get a segfault later when trying to actually use it.
     254              :     #[cfg(not(target_os = "macos"))]
     255              :     {
     256            8 :         nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
     257            0 :             Error::new(
     258            0 :                 "could not grow shmem segment, posix_fallocate failed: {e}",
     259            0 :                 e,
     260              :             )
     261            0 :         })
     262              :     }
     263              :     // As a fallback on macOS, which doesn't have posix_fallocate, use plain 'ftruncate'
     264              :     #[cfg(target_os = "macos")]
     265              :     {
     266              :         nix::unistd::ftruncate(fd, size as i64)
     267              :             .map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
     268              :     }
     269            8 : }
     270              : 
     271              : #[cfg(test)]
     272              : mod tests {
     273              :     use super::*;
     274              : 
     275              :     use nix::unistd::ForkResult;
     276              :     use std::ops::Range;
     277              : 
     278              :     /// check that all bytes in given range have the expected value.
     279           13 :     fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
     280        26013 :         for i in range {
     281        26000 :             let b = unsafe { *(ptr.add(i)) };
     282        26000 :             assert_eq!(expected, b, "unexpected byte at offset {i}");
     283              :         }
     284           13 :     }
     285              : 
     286              :     /// Write 'b' to all bytes in the given range
     287            5 :     fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
     288            5 :         unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
     289            5 :     }
     290              : 
     291              :     // simple single-process test of growing and shrinking
     292              :     #[test]
     293            1 :     fn test_shmem_resize() -> Result<(), Error> {
     294            1 :         let max_size = 1024 * 1024;
     295            1 :         let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
     296              : 
     297            1 :         assert_eq!(init_struct.current_size(), 0);
     298              : 
     299              :         // Initial grow
     300            1 :         let size1 = 10000;
     301            1 :         init_struct.set_size(size1).unwrap();
     302            1 :         assert_eq!(init_struct.current_size(), size1);
     303              : 
     304              :         // Write some data
     305            1 :         let data_ptr = init_struct.data_ptr.as_ptr();
     306            1 :         write_range(data_ptr, 0xAA, 0..size1);
     307            1 :         assert_range(data_ptr, 0xAA, 0..size1);
     308              : 
     309              :         // Shrink
     310            1 :         let size2 = 5000;
     311            1 :         init_struct.set_size(size2).unwrap();
     312            1 :         assert_eq!(init_struct.current_size(), size2);
     313              : 
     314              :         // Grow again
     315            1 :         let size3 = 20000;
     316            1 :         init_struct.set_size(size3).unwrap();
     317            1 :         assert_eq!(init_struct.current_size(), size3);
     318              : 
     319              :         // Try to read it. The area that was shrunk and grown again should read as all zeros now
     320            1 :         assert_range(data_ptr, 0xAA, 0..5000);
     321            1 :         assert_range(data_ptr, 0, 5000..size1);
     322              : 
     323              :         // Try to grow beyond max_size
     324              :         //let size4 = max_size + 1;
     325              :         //assert!(init_struct.set_size(size4).is_err());
     326              : 
     327              :         // Dropping init_struct should unmap the memory
     328            1 :         drop(init_struct);
     329              : 
     330            1 :         Ok(())
     331            1 :     }
     332              : 
     333              :     /// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
     334              :     /// but is stored in the shared memory area and works across processes. It's implemented by
     335              :     /// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
     336              :     struct SimpleBarrier {
     337              :         num_procs: usize,
     338              :         count: AtomicUsize,
     339              :     }
     340              : 
     341              :     impl SimpleBarrier {
     342            2 :         unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
     343              :             unsafe {
     344            2 :                 *ptr = SimpleBarrier {
     345            2 :                     num_procs,
     346            2 :                     count: AtomicUsize::new(0),
     347            2 :                 }
     348              :             }
     349            2 :         }
     350              : 
     351            6 :         pub fn wait(&self) {
     352            6 :             let old = self.count.fetch_add(1, Ordering::Relaxed);
     353              : 
     354            6 :             let generation = old / self.num_procs;
     355              : 
     356            6 :             let mut current = old + 1;
     357            9 :             while current < (generation + 1) * self.num_procs {
     358            3 :                 std::thread::sleep(std::time::Duration::from_millis(10));
     359            3 :                 current = self.count.load(Ordering::Relaxed);
     360            3 :             }
     361            6 :         }
     362              :     }
     363              : 
     364              :     #[test]
     365            2 :     fn test_multi_process() {
     366              :         // Initialize
     367            2 :         let max_size = 1_000_000_000_000;
     368            2 :         let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
     369            2 :         let ptr = init_struct.data_ptr.as_ptr();
     370              : 
     371              :         // Store the SimpleBarrier in the first 1k of the area.
     372            2 :         init_struct.set_size(10000).unwrap();
     373            2 :         let barrier_ptr: *mut SimpleBarrier = unsafe {
     374            2 :             ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
     375            2 :                 .cast()
     376              :         };
     377            2 :         unsafe { SimpleBarrier::init(barrier_ptr, 2) };
     378            2 :         let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
     379              : 
     380              :         // Fork another test process. The code after this runs in both processes concurrently.
     381            2 :         let fork_result = unsafe { nix::unistd::fork().unwrap() };
     382              : 
     383              :         // In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
     384            2 :         if fork_result.is_parent() {
     385            1 :             write_range(ptr, 0xAA, 1000..2000);
     386            1 :         } else {
     387            1 :             write_range(ptr, 0xBB, 2000..3000);
     388            1 :         }
     389            2 :         barrier.wait();
     390              :         // Verify the contents. (in both processes)
     391            2 :         assert_range(ptr, 0xAA, 1000..2000);
     392            2 :         assert_range(ptr, 0xBB, 2000..3000);
     393              : 
     394              :         // Grow, from the child this time
     395            2 :         let size = 10_000_000;
     396            2 :         if !fork_result.is_parent() {
     397            1 :             init_struct.set_size(size).unwrap();
     398            1 :         }
     399            2 :         barrier.wait();
     400              : 
     401              :         // make some writes at the end
     402            2 :         if fork_result.is_parent() {
     403            1 :             write_range(ptr, 0xAA, (size - 10)..size);
     404            1 :         } else {
     405            1 :             write_range(ptr, 0xBB, (size - 20)..(size - 10));
     406            1 :         }
     407            2 :         barrier.wait();
     408              : 
     409              :         // Verify the contents. (This runs in both processes)
     410            2 :         assert_range(ptr, 0, (size - 1000)..(size - 20));
     411            2 :         assert_range(ptr, 0xBB, (size - 20)..(size - 10));
     412            2 :         assert_range(ptr, 0xAA, (size - 10)..size);
     413              : 
     414            2 :         if let ForkResult::Parent { child } = fork_result {
     415            1 :             nix::sys::wait::waitpid(child, None).unwrap();
     416            1 :         }
     417            2 :     }
     418              : }
        

Generated by: LCOV version 2.1-beta