//! Dynamically resizable contiguous chunk of shared memory

use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};

use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;

/// `ShmemHandle` represents a shared memory area that can be shared by processes over `fork()`.
/// Unlike shared memory allocated by Postgres, this area is resizable, up to the `max_size`
/// specified at creation.
///
/// The area is backed by an anonymous file created with `memfd_create()`. The full address space
/// for `max_size` is reserved up-front with `mmap()`, but whenever you call
/// [`ShmemHandle::set_size`], the underlying file is resized. Do not access the area beyond the
/// current size. Currently, that will cause the file to be expanded, but we might use
/// `mprotect()` etc. to enforce that in the future.
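///
/// # Example
///
/// A minimal single-process sketch (marked `ignore` because the crate path is omitted here;
/// the name and sizes are arbitrary):
///
/// ```ignore
/// let shmem = ShmemHandle::new("my_area", 4096, 1024 * 1024)?;
/// assert_eq!(shmem.current_size(), 4096);
///
/// // Grow the area, then write to the newly usable part through `data_ptr`.
/// shmem.set_size(8192)?;
/// unsafe { shmem.data_ptr.as_ptr().add(4096).write(42u8) };
/// ```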
pub struct ShmemHandle {
    /// memfd file descriptor
    fd: OwnedFd,

    max_size: usize,

    // Pointer to the beginning of the shared memory area. The header is stored there.
    shared_ptr: NonNull<SharedStruct>,

    // Pointer to the beginning of the user data
    pub data_ptr: NonNull<u8>,
}

/// This is stored at the beginning of the shared memory area.
struct SharedStruct {
    max_size: usize,

    /// Current size of the backing file. The high-order bit is used for the
    /// [`RESIZE_IN_PROGRESS`] flag.
    current_size: AtomicUsize,
}

/// Flag bit in [`SharedStruct::current_size`], set while a resize is in progress.
const RESIZE_IN_PROGRESS: usize = 1 << 63;

/// Size of the [`SharedStruct`] header stored at the beginning of the area.
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();

/// Error type returned by the [`ShmemHandle`] functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
    pub msg: String,
    pub errno: Errno,
}

impl Error {
    fn new(msg: &str, errno: Errno) -> Self {
        Self {
            msg: msg.to_string(),
            errno,
        }
    }
}

impl ShmemHandle {
    /// Create a new shared memory area. To share the area between processes, `fork()` after
    /// calling this, so that the `ShmemHandle` is inherited by all child processes.
    ///
    /// If the `ShmemHandle` is dropped, the memory is unmapped from the current process. Other
    /// processes can continue using it, however.
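    ///
    /// A rough sketch of sharing the area over `fork()` (marked `ignore` because the crate path
    /// is omitted and error handling is elided):
    ///
    /// ```ignore
    /// let shmem = ShmemHandle::new("my_area", 4096, 1024 * 1024)?;
    /// match unsafe { nix::unistd::fork() }? {
    ///     nix::unistd::ForkResult::Parent { child } => {
    ///         // Parent and child see the same memory through `shmem.data_ptr`.
    ///         nix::sys::wait::waitpid(child, None)?;
    ///     }
    ///     nix::unistd::ForkResult::Child => {
    ///         // ... the child can read and write through `shmem.data_ptr` here ...
    ///     }
    /// }
    /// ```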
    pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<Self, Error> {
        // create the backing anonymous file.
        let fd = create_backing_file(name)?;

        Self::new_with_fd(fd, initial_size, max_size)
    }

    fn new_with_fd(fd: OwnedFd, initial_size: usize, max_size: usize) -> Result<Self, Error> {
        // We reserve the high-order bit of the size field for the `RESIZE_IN_PROGRESS` flag, and
        // the actual size is a little larger than requested because of the `SharedStruct` header.
        // Make the upper limit comfortably smaller than that; with anything close to it, you'll
        // run out of memory anyway.
        assert!(max_size < 1 << 48, "max size {max_size} too large");

        assert!(
            initial_size <= max_size,
            "initial size {initial_size} larger than max size {max_size}"
        );

        // The actual initial / max size is the one given by the caller, plus the size of
        // 'SharedStruct'.
        let initial_size = HEADER_SIZE + initial_size;
        let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();

        // Reserve address space for it with mmap
        //
        // TODO: Use MAP_HUGETLB if possible
        let start_ptr = unsafe {
            nix_mmap(
                None,
                max_size,
                ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
                MapFlags::MAP_SHARED,
                &fd,
                0,
            )
        }
        .map_err(|e| Error::new("mmap failed", e))?;

        // Reserve space for the initial size
        enlarge_file(fd.as_fd(), initial_size as u64)?;

        // Initialize the header
        let shared: NonNull<SharedStruct> = start_ptr.cast();
        unsafe {
            shared.write(SharedStruct {
                max_size: max_size.into(),
                current_size: AtomicUsize::new(initial_size),
            });
        }

        // The user data begins after the header
        let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };

        Ok(Self {
            fd,
            max_size: max_size.into(),
            shared_ptr: shared,
            data_ptr,
        })
    }

    // Return a reference to the header.
    fn shared(&self) -> &SharedStruct {
        unsafe { self.shared_ptr.as_ref() }
    }

    /// Resize the shared memory area. `new_size` must not be larger than the `max_size` specified
    /// when creating the area.
    ///
    /// Only one process/thread may call this at a time. Concurrent calls are detected and return
    /// an [`shmem::Error`](Error).
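    ///
    /// A minimal sketch, given an existing `shmem: ShmemHandle` (marked `ignore` because the
    /// crate path is omitted here):
    ///
    /// ```ignore
    /// // Grow the area, then shrink it again. After shrinking, the truncated
    /// // part must not be accessed.
    /// shmem.set_size(64 * 1024)?;
    /// assert_eq!(shmem.current_size(), 64 * 1024);
    /// shmem.set_size(16 * 1024)?;
    /// ```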
    pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
        let new_size = new_size + HEADER_SIZE;
        let shared = self.shared();

        assert!(
            new_size <= self.max_size,
            "new size ({new_size}) is greater than max size ({})",
            self.max_size
        );

        assert_eq!(self.max_size, shared.max_size);

        // Lock the area by setting the bit in `current_size`
        //
        // Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
        // and the `posix_fallocate`/`ftruncate` call is surely a synchronization point anyway. But
        // since this is not performance-critical, better safe than sorry.
        let mut old_size = shared.current_size.load(Ordering::Acquire);
        loop {
            if (old_size & RESIZE_IN_PROGRESS) != 0 {
                return Err(Error::new(
                    "concurrent resize detected",
                    Errno::UnknownErrno,
                ));
            }
            // Store the new size with the `RESIZE_IN_PROGRESS` bit set; the bit is cleared
            // again below, once the resize is complete.
            match shared.current_size.compare_exchange(
                old_size,
                new_size | RESIZE_IN_PROGRESS,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(x) => old_size = x,
            }
        }

        // Ok, we got the lock.
        //
        // NB: If anything goes wrong, we *must* clear the bit!
        let result = {
            use std::cmp::Ordering::{Equal, Greater, Less};
            match new_size.cmp(&old_size) {
                Less => nix_ftruncate(&self.fd, new_size as i64)
                    .map_err(|e| Error::new("could not shrink shmem segment, ftruncate failed", e)),
                Equal => Ok(()),
                Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
            }
        };

        // Unlock
        shared.current_size.store(
            if result.is_ok() { new_size } else { old_size },
            Ordering::Release,
        );

        result
    }

    /// Returns the current user-visible size of the shared memory segment.
    ///
    /// NOTE: a concurrent [`ShmemHandle::set_size()`] call can change the size at any time.
    /// It is the caller's responsibility not to access the area beyond the current size.
    pub fn current_size(&self) -> usize {
        let total_current_size =
            self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
        total_current_size - HEADER_SIZE
    }
}

impl Drop for ShmemHandle {
    fn drop(&mut self) {
        // SAFETY: The pointer was obtained from mmap() with the given size.
        // We unmap the entire region.
        let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
        // The fd is dropped automatically by OwnedFd.
    }
}

/// Create a "backing file" for the shared memory area. On Linux, use `memfd_create()` to create
/// an anonymous in-memory file. On macOS, fall back to a regular temporary file. That's good
/// enough for development and testing, but in production we want the file to stay in memory.
///
/// The unused-variables warning is disabled because `name` is unused in the macOS path.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
    #[cfg(not(target_os = "macos"))]
    {
        nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
            .map_err(|e| Error::new("memfd_create failed", e))
    }
    #[cfg(target_os = "macos")]
    {
        let file = tempfile::tempfile().map_err(|e| {
            Error::new(
                "could not create temporary file to back shmem area",
                nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
            )
        })?;
        Ok(OwnedFd::from(file))
    }
}

fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
    // Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
    // we don't get a segfault later when trying to actually use it.
    #[cfg(not(target_os = "macos"))]
    {
        nix::fcntl::posix_fallocate(fd, 0, size as i64)
            .map_err(|e| Error::new("could not grow shmem segment, posix_fallocate failed", e))
    }
    // As a fallback on macOS, which doesn't have posix_fallocate, use plain ftruncate().
    #[cfg(target_os = "macos")]
    {
        nix::unistd::ftruncate(fd, size as i64)
            .map_err(|e| Error::new("could not grow shmem segment, ftruncate failed", e))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use nix::unistd::ForkResult;
    use std::ops::Range;

    /// Check that all bytes in the given range have the expected value.
    fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
        for i in range {
            let b = unsafe { *(ptr.add(i)) };
            assert_eq!(expected, b, "unexpected byte at offset {i}");
        }
    }

    /// Write 'b' to all bytes in the given range
    fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
        unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
    }

    // simple single-process test of growing and shrinking
    #[test]
    fn test_shmem_resize() -> Result<(), Error> {
        let max_size = 1024 * 1024;
        let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;

        assert_eq!(init_struct.current_size(), 0);

        // Initial grow
        let size1 = 10000;
        init_struct.set_size(size1).unwrap();
        assert_eq!(init_struct.current_size(), size1);

        // Write some data
        let data_ptr = init_struct.data_ptr.as_ptr();
        write_range(data_ptr, 0xAA, 0..size1);
        assert_range(data_ptr, 0xAA, 0..size1);

        // Shrink
        let size2 = 5000;
        init_struct.set_size(size2).unwrap();
        assert_eq!(init_struct.current_size(), size2);

        // Grow again
        let size3 = 20000;
        init_struct.set_size(size3).unwrap();
        assert_eq!(init_struct.current_size(), size3);

        // Try to read it. The area that was shrunk and grown again should read as all zeros now
        assert_range(data_ptr, 0xAA, 0..5000);
        assert_range(data_ptr, 0, 5000..size1);

        // Try to grow beyond max_size
        //let size4 = max_size + 1;
        //assert!(init_struct.set_size(size4).is_err());

        // Dropping init_struct should unmap the memory
        drop(init_struct);

        Ok(())
    }

    /// This is used in tests to coordinate between test processes. It's like `std::sync::Barrier`,
    /// but is stored in the shared memory area and works across processes. It's implemented by
    /// polling, because e.g. standard Rust mutexes are not guaranteed to work across processes.
    struct SimpleBarrier {
        num_procs: usize,
        count: AtomicUsize,
    }

    impl SimpleBarrier {
        unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
            unsafe {
                *ptr = SimpleBarrier {
                    num_procs,
                    count: AtomicUsize::new(0),
                }
            }
        }

        pub fn wait(&self) {
            let old = self.count.fetch_add(1, Ordering::Relaxed);

            let generation = old / self.num_procs;

            let mut current = old + 1;
            while current < (generation + 1) * self.num_procs {
                std::thread::sleep(std::time::Duration::from_millis(10));
                current = self.count.load(Ordering::Relaxed);
            }
        }
    }

    #[test]
    fn test_multi_process() {
        // Initialize
        let max_size = 1_000_000_000_000;
        let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
        let ptr = init_struct.data_ptr.as_ptr();

        // Store the SimpleBarrier in the first 1k of the area.
        init_struct.set_size(10000).unwrap();
        let barrier_ptr: *mut SimpleBarrier = unsafe {
            ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
                .cast()
        };
        unsafe { SimpleBarrier::init(barrier_ptr, 2) };
        let barrier = unsafe { barrier_ptr.as_ref().unwrap() };

        // Fork another test process. The code after this runs in both processes concurrently.
        let fork_result = unsafe { nix::unistd::fork().unwrap() };

        // In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
        if fork_result.is_parent() {
            write_range(ptr, 0xAA, 1000..2000);
        } else {
            write_range(ptr, 0xBB, 2000..3000);
        }
        barrier.wait();
        // Verify the contents. (in both processes)
        assert_range(ptr, 0xAA, 1000..2000);
        assert_range(ptr, 0xBB, 2000..3000);

        // Grow, from the child this time
        let size = 10_000_000;
        if !fork_result.is_parent() {
            init_struct.set_size(size).unwrap();
        }
        barrier.wait();

        // make some writes at the end
        if fork_result.is_parent() {
            write_range(ptr, 0xAA, (size - 10)..size);
        } else {
            write_range(ptr, 0xBB, (size - 20)..(size - 10));
        }
        barrier.wait();

        // Verify the contents. (This runs in both processes)
        assert_range(ptr, 0, (size - 1000)..(size - 20));
        assert_range(ptr, 0xBB, (size - 20)..(size - 10));
        assert_range(ptr, 0xAA, (size - 10)..size);

        if let ForkResult::Parent { child } = fork_result {
            nix::sys::wait::waitpid(child, None).unwrap();
        }
    }
}