|
// Copyright 2021-Present Datadog, Inc. https://www.datadoghq.com/
// SPDX-License-Identifier: Apache-2.0

//! Lock-free `Option<T>` with atomic take, valid for any `T` where
//! `size_of::<Option<T>>() <= 8`.
|
| 7 | +use std::cell::UnsafeCell; |
| 8 | +use std::mem::{self, MaybeUninit}; |
| 9 | +use std::ptr; |
| 10 | +use std::sync::atomic::{AtomicU16, AtomicU32, AtomicU64, AtomicU8, Ordering}; |
| 11 | + |
/// An `Option<T>` that supports lock-free atomic take.
///
/// # Constraints
/// `size_of::<Option<T>>()` must be `<= 8` (enforced by a `debug_assert` in
/// `From<Option<T>>`). This holds for niche-optimised types (`NonNull<T>`,
/// `Box<T>`, …) and for any `Option<T>` that fits in a single machine word.
///
/// # Storage
/// The option is stored in a `UnsafeCell<Option<T>>`, giving it exactly the size
/// and alignment of `Option<T>` itself. `take()` picks the narrowest atomic that
/// covers `size_of::<Option<T>>()` bytes — `AtomicU8` for 1-byte options up to
/// `AtomicU64` for 5–8 byte options.
///
/// NOTE(review): overlaying an atomic on the cell is only sound when the atomic
/// exactly covers the option (sizes 3 and 5–7 would make it touch bytes past the
/// end of the value) and when `align_of::<Option<T>>()` is at least the atomic's
/// alignment — e.g. `Option<u32>` is size 8 but align 4. Verify the instantiating
/// types satisfy both.
///
/// # None sentinel
/// The "none" bit-pattern is computed by value (`Option::<T>::None`) rather than
/// assumed to be zero, so the implementation is correct for both niche-optimised
/// types and discriminant-based options.
///
/// `UnsafeCell` provides the interior-mutability aliasing permission required by
/// Rust's memory model when mutating through a shared reference.
pub struct AtomicOption<T>(UnsafeCell<Option<T>>);
| 34 | + |
| 35 | +impl<T> AtomicOption<T> { |
| 36 | + /// Encode `val` as a `u64`, transferring ownership into the bit representation. |
| 37 | + const fn encode(val: Option<T>) -> u64 { |
| 38 | + let mut bits = 0u64; |
| 39 | + unsafe { |
| 40 | + ptr::copy_nonoverlapping( |
| 41 | + ptr::from_ref(&val).cast::<u8>(), |
| 42 | + ptr::from_mut(&mut bits).cast::<u8>(), |
| 43 | + size_of::<Option<T>>(), |
| 44 | + ); |
| 45 | + mem::forget(val); |
| 46 | + } |
| 47 | + bits |
| 48 | + } |
| 49 | + |
| 50 | + /// Atomically swap the storage with `new_bits`, returning the old bits. |
| 51 | + #[inline] |
| 52 | + fn atomic_swap(&self, new_bits: u64) -> u64 { |
| 53 | + unsafe { |
| 54 | + let ptr = self.0.get(); |
| 55 | + match size_of::<Option<T>>() { |
| 56 | + 1 => (*(ptr as *const AtomicU8)).swap(new_bits as u8, Ordering::AcqRel) as u64, |
| 57 | + 2 => (*(ptr as *const AtomicU16)).swap(new_bits as u16, Ordering::AcqRel) as u64, |
| 58 | + 3 | 4 => { |
| 59 | + (*(ptr as *const AtomicU32)).swap(new_bits as u32, Ordering::AcqRel) as u64 |
| 60 | + } |
| 61 | + _ => (*(ptr as *const AtomicU64)).swap(new_bits, Ordering::AcqRel), |
| 62 | + } |
| 63 | + } |
| 64 | + } |
| 65 | + |
| 66 | + /// Reconstruct an `Option<T>` from its `u64` bit representation. |
| 67 | + /// |
| 68 | + /// # Safety |
| 69 | + /// `bits` must hold a valid `Option<T>` bit-pattern in its low |
| 70 | + /// `size_of::<Option<T>>()` bytes, as produced by a previous `encode`. |
| 71 | + const unsafe fn decode(bits: u64) -> Option<T> { |
| 72 | + let mut result = MaybeUninit::<Option<T>>::uninit(); |
| 73 | + ptr::copy_nonoverlapping( |
| 74 | + ptr::from_ref(&bits).cast::<u8>(), |
| 75 | + result.as_mut_ptr().cast::<u8>(), |
| 76 | + size_of::<Option<T>>(), |
| 77 | + ); |
| 78 | + result.assume_init() |
| 79 | + } |
| 80 | + |
| 81 | + /// Atomically replace the stored value with `None` and return what was there. |
| 82 | + /// Returns `None` if the value was already taken. |
| 83 | + pub fn take(&self) -> Option<T> { |
| 84 | + let old = self.atomic_swap(Self::encode(None)); |
| 85 | + // SAFETY: `old` holds a valid `Option<T>` bit-pattern. |
| 86 | + unsafe { Self::decode(old) } |
| 87 | + } |
| 88 | + |
| 89 | + /// Atomically store `val`, dropping any previous value. |
| 90 | + pub fn set(&self, val: Option<T>) -> Option<T> { |
| 91 | + let old = self.atomic_swap(Self::encode(val)); |
| 92 | + unsafe { Self::decode(old) } |
| 93 | + } |
| 94 | + |
| 95 | + /// Atomically store `Some(val)`, returning the previous value. |
| 96 | + pub fn replace(&self, val: T) -> Option<T> { |
| 97 | + self.set(Some(val)) |
| 98 | + } |
| 99 | + |
| 100 | + /// Borrow the current value without taking it. |
| 101 | + /// |
| 102 | + /// # Safety |
| 103 | + /// Must not be called concurrently with [`take`], [`set`], or [`replace`]. |
| 104 | + pub unsafe fn as_option(&self) -> &Option<T> { |
| 105 | + &*self.0.get() |
| 106 | + } |
| 107 | +} |
| 108 | + |
| 109 | +impl<T> From<Option<T>> for AtomicOption<T> { |
| 110 | + fn from(val: Option<T>) -> Self { |
| 111 | + // we may raise this to 16 once AtomicU128 becomes stable |
| 112 | + debug_assert!( |
| 113 | + size_of::<Option<T>>() <= size_of::<u64>(), |
| 114 | + "AtomicOption requires size_of::<Option<T>>() <= 8, got {}", |
| 115 | + size_of::<Option<T>>() |
| 116 | + ); |
| 117 | + Self(UnsafeCell::new(val)) |
| 118 | + } |
| 119 | +} |
| 120 | + |
// SAFETY: `AtomicOption` moves owned `T` values between threads through
// `take`/`set`/`replace`, so `T: Send` is required for both impls; no safe API
// hands out `&T` across threads, so `T: Sync` is not needed — the same
// contract as `Mutex<Option<T>>`.
unsafe impl<T: Send> Send for AtomicOption<T> {}
unsafe impl<T: Send> Sync for AtomicOption<T> {}