//! Shared memory utilities for neon communicator

use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};

use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;

/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
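///
/// # Example
///
/// A minimal single-process sketch of the intended usage, mirroring the unit tests at the bottom
/// of this file (marked `ignore` because it maps real shared memory; the name is arbitrary):
///
/// ```ignore
/// let shmem = ShmemHandle::new("example-area", 0, 1024 * 1024)?;
///
/// // Grow the area before touching it, then write through the raw data pointer,
/// // staying within current_size().
/// shmem.set_size(4096)?;
/// assert_eq!(shmem.current_size(), 4096);
/// unsafe { std::ptr::write_bytes(shmem.data_ptr.as_ptr(), 0xAB, shmem.current_size()) };
/// ```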
pub struct ShmemHandle {
    /// memfd file descriptor
    fd: OwnedFd,

    max_size: usize,

    // Pointer to the beginning of the shared memory area. The header is stored there.
    shared_ptr: NonNull<SharedStruct>,

    // Pointer to the beginning of the user data
    pub data_ptr: NonNull<u8>,
}

/// This is stored at the beginning of the shared memory area.
struct SharedStruct {
    max_size: usize,

    /// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag.
    current_size: AtomicUsize,
}

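/// Lock bit stored in the high-order bit of `SharedStruct::current_size` while a resize is in
/// progress: `current_size & !RESIZE_IN_PROGRESS` recovers the size in bytes, and a set bit
/// means another set_size() call is underway.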
const RESIZE_IN_PROGRESS: usize = 1 << 63;

const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();

/// Error type returned by the ShmemHandle functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
    pub msg: String,
    pub errno: Errno,
}

impl Error {
    fn new(msg: &str, errno: Errno) -> Error {
        Error {
            msg: msg.to_string(),
            errno,
        }
    }
}

impl ShmemHandle {
    /// Create a new shared memory area. To communicate between processes, the processes need to be
    /// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
    ///
    /// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
    /// processes can continue using it, however.
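    ///
    /// # Example
    ///
    /// A rough multi-process sketch along the lines of the `test_multi_process` test below
    /// (marked `ignore` because it forks; the name and sizes are arbitrary, error handling omitted):
    ///
    /// ```ignore
    /// let shmem = ShmemHandle::new("my-area", 4096, 1024 * 1024)?;
    ///
    /// // Fork *after* creating the handle so the child inherits the same mapping.
    /// match unsafe { nix::unistd::fork()? } {
    ///     nix::unistd::ForkResult::Parent { child } => {
    ///         // Both processes now see the same bytes through shmem.data_ptr.
    ///         nix::sys::wait::waitpid(child, None)?;
    ///     }
    ///     nix::unistd::ForkResult::Child => {
    ///         // ... read/write through shmem.data_ptr here ...
    ///     }
    /// }
    /// ```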
    pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
        // create the backing anonymous file.
        let fd = create_backing_file(name)?;

        Self::new_with_fd(fd, initial_size, max_size)
    }

    fn new_with_fd(
        fd: OwnedFd,
        initial_size: usize,
        max_size: usize,
    ) -> Result<ShmemHandle, Error> {
        // We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
        // is a little larger than this because of the SharedStruct header. Make the upper limit
        // somewhat smaller than that, because with anything close to that, you'll run out of
        // memory anyway.
        if max_size >= 1 << 48 {
            panic!("max size {max_size} too large");
        }
        if initial_size > max_size {
            panic!("initial size {initial_size} larger than max size {max_size}");
        }

        // The actual initial / max size is the one given by the caller, plus the size of
        // 'SharedStruct'.
        let initial_size = HEADER_SIZE + initial_size;
        let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();

        // Reserve address space for it with mmap
        //
        // TODO: Use MAP_HUGETLB if possible
        let start_ptr = unsafe {
            nix_mmap(
                None,
                max_size,
                ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
                MapFlags::MAP_SHARED,
                &fd,
                0,
            )
        }
        .map_err(|e| Error::new("mmap failed", e))?;

        // Reserve space for the initial size
        enlarge_file(fd.as_fd(), initial_size as u64)?;

        // Initialize the header
        let shared: NonNull<SharedStruct> = start_ptr.cast();
        unsafe {
            shared.write(SharedStruct {
                max_size: max_size.into(),
                current_size: AtomicUsize::new(initial_size),
            })
        };

        // The user data begins after the header
        let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };

        Ok(ShmemHandle {
            fd,
            max_size: max_size.into(),
            shared_ptr: shared,
            data_ptr,
        })
    }

    // return reference to the header
    fn shared(&self) -> &SharedStruct {
        unsafe { self.shared_ptr.as_ref() }
    }

    /// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
    /// when creating the area.
    ///
    /// This may only be called from one process/thread concurrently. We detect that case
    /// and return an Error.
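    ///
    /// # Example
    ///
    /// A small sketch of growing and shrinking, based on the `test_shmem_resize` test below
    /// (marked `ignore` because it maps real shared memory; the name is arbitrary):
    ///
    /// ```ignore
    /// let shmem = ShmemHandle::new("resize-example", 0, 1024 * 1024)?;
    /// shmem.set_size(10_000)?; // grow the backing file
    /// shmem.set_size(5_000)?;  // shrink it; the truncated range is discarded
    /// shmem.set_size(20_000)?; // grow again; the re-added range reads back as zeros
    /// assert_eq!(shmem.current_size(), 20_000);
    /// ```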
    pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
        let new_size = new_size + HEADER_SIZE;
        let shared = self.shared();

        if new_size > self.max_size {
            panic!(
                "new size ({}) is greater than max size ({})",
                new_size, self.max_size
            );
        }
        assert_eq!(self.max_size, shared.max_size);

        // Lock the area by setting the bit in 'current_size'
        //
        // Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
        // and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
        // since this is not performance-critical, better safe than sorry.
        let mut old_size = shared.current_size.load(Ordering::Acquire);
        loop {
            if (old_size & RESIZE_IN_PROGRESS) != 0 {
                return Err(Error::new(
                    "concurrent resize detected",
                    Errno::UnknownErrno,
                ));
            }
            match shared.current_size.compare_exchange(
                old_size,
                new_size | RESIZE_IN_PROGRESS,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(x) => old_size = x,
            }
        }

        // Ok, we got the lock.
        //
        // NB: If anything goes wrong, we *must* clear the bit!
        let result = {
            use std::cmp::Ordering::{Equal, Greater, Less};
            match new_size.cmp(&old_size) {
                Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
                    Error::new("could not shrink shmem segment, ftruncate failed", e)
                }),
                Equal => Ok(()),
                Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
            }
        };

        // Unlock
        shared.current_size.store(
            if result.is_ok() { new_size } else { old_size },
            Ordering::Release,
        );

        result
    }

    /// Returns the current user-visible size of the shared memory segment.
    ///
    /// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
    /// responsibility not to access the area beyond the current size.
    pub fn current_size(&self) -> usize {
        let total_current_size =
            self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
        total_current_size - HEADER_SIZE
    }
}

impl Drop for ShmemHandle {
    fn drop(&mut self) {
        // SAFETY: The pointer was obtained from mmap() with the given size.
        // We unmap the entire region.
        let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
        // The fd is dropped automatically by OwnedFd.
    }
}

/// Create a "backing file" for the shared memory area. On Linux, use memfd_create() to create an
/// anonymous in-memory file. On macOS, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// Disable 'unused_variables' warnings, because in the macOS path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
    #[cfg(not(target_os = "macos"))]
    {
        nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
            .map_err(|e| Error::new("memfd_create failed", e))
    }
    #[cfg(target_os = "macos")]
    {
        let file = tempfile::tempfile().map_err(|e| {
            Error::new(
                "could not create temporary file to back shmem area",
                nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
            )
        })?;
        Ok(OwnedFd::from(file))
    }
}

fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
    // Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
    // we don't get a segfault later when trying to actually use it.
    #[cfg(not(target_os = "macos"))]
    {
        nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
            Error::new("could not grow shmem segment, posix_fallocate failed", e)
        })
    }
    // As a fallback on macOS, which doesn't have posix_fallocate, use plain ftruncate()
    #[cfg(target_os = "macos")]
    {
        nix::unistd::ftruncate(fd, size as i64)
            .map_err(|e| Error::new("could not grow shmem segment, ftruncate failed", e))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use nix::unistd::ForkResult;
    use std::ops::Range;

    /// check that all bytes in the given range have the expected value.
    fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
        for i in range {
            let b = unsafe { *(ptr.add(i)) };
            assert_eq!(expected, b, "unexpected byte at offset {i}");
        }
    }

    /// Write 'b' to all bytes in the given range
    fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
        unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
    }

    // simple single-process test of growing and shrinking
    #[test]
    fn test_shmem_resize() -> Result<(), Error> {
        let max_size = 1024 * 1024;
        let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;

        assert_eq!(init_struct.current_size(), 0);

        // Initial grow
        let size1 = 10000;
        init_struct.set_size(size1).unwrap();
        assert_eq!(init_struct.current_size(), size1);

        // Write some data
        let data_ptr = init_struct.data_ptr.as_ptr();
        write_range(data_ptr, 0xAA, 0..size1);
        assert_range(data_ptr, 0xAA, 0..size1);

        // Shrink
        let size2 = 5000;
        init_struct.set_size(size2).unwrap();
        assert_eq!(init_struct.current_size(), size2);

        // Grow again
        let size3 = 20000;
        init_struct.set_size(size3).unwrap();
        assert_eq!(init_struct.current_size(), size3);

        // Try to read it. The area that was shrunk and grown again should read as all zeros now
        assert_range(data_ptr, 0xAA, 0..5000);
        assert_range(data_ptr, 0, 5000..size1);

        // Try to grow beyond max_size
        //let size4 = max_size + 1;
        //assert!(init_struct.set_size(size4).is_err());

        // Dropping init_struct should unmap the memory
        drop(init_struct);

        Ok(())
    }

    /// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
    /// but is stored in the shared memory area and works across processes. It's implemented by
    /// polling, because e.g. standard Rust mutexes are not guaranteed to work across processes.
    struct SimpleBarrier {
        num_procs: usize,
        count: AtomicUsize,
    }

    impl SimpleBarrier {
        unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
            unsafe {
                *ptr = SimpleBarrier {
                    num_procs,
                    count: AtomicUsize::new(0),
                }
            }
        }

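        /// Block until all `num_procs` participants have called wait(). The running count is
        /// divided into "generations" of `num_procs` calls each, so the same barrier can be
        /// reused: a caller in generation `g` (i.e. `old / num_procs == g`) polls until the
        /// count reaches `(g + 1) * num_procs`.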
        pub fn wait(&self) {
            let old = self.count.fetch_add(1, Ordering::Relaxed);

            let generation = old / self.num_procs;

            let mut current = old + 1;
            while current < (generation + 1) * self.num_procs {
                std::thread::sleep(std::time::Duration::from_millis(10));
                current = self.count.load(Ordering::Relaxed);
            }
        }
    }

    #[test]
    fn test_multi_process() {
        // Initialize
        let max_size = 1_000_000_000_000;
        let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
        let ptr = init_struct.data_ptr.as_ptr();

        // Store the SimpleBarrier in the first 1k of the area.
        init_struct.set_size(10000).unwrap();
        let barrier_ptr: *mut SimpleBarrier = unsafe {
            ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
                .cast()
        };
        unsafe { SimpleBarrier::init(barrier_ptr, 2) };
        let barrier = unsafe { barrier_ptr.as_ref().unwrap() };

        // Fork another test process. The code after this runs in both processes concurrently.
        let fork_result = unsafe { nix::unistd::fork().unwrap() };

        // In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
        if fork_result.is_parent() {
            write_range(ptr, 0xAA, 1000..2000);
        } else {
            write_range(ptr, 0xBB, 2000..3000);
        }
        barrier.wait();
        // Verify the contents. (in both processes)
        assert_range(ptr, 0xAA, 1000..2000);
        assert_range(ptr, 0xBB, 2000..3000);

        // Grow, from the child this time
        let size = 10_000_000;
        if !fork_result.is_parent() {
            init_struct.set_size(size).unwrap();
        }
        barrier.wait();

        // make some writes at the end
        if fork_result.is_parent() {
            write_range(ptr, 0xAA, (size - 10)..size);
        } else {
            write_range(ptr, 0xBB, (size - 20)..(size - 10));
        }
        barrier.wait();

        // Verify the contents. (This runs in both processes)
        assert_range(ptr, 0, (size - 1000)..(size - 20));
        assert_range(ptr, 0xBB, (size - 20)..(size - 10));
        assert_range(ptr, 0xAA, (size - 10)..size);

        if let ForkResult::Parent { child } = fork_result {
            nix::sys::wait::waitpid(child, None).unwrap();
        }
    }
}