diff --git a/CHANGELOG.md b/CHANGELOG.md index d21f26cbb7..2a0a5397d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ of panicking drop implementations. - Added `CString::{into_bytes, into_bytes_with_nul, into_string}` - Added `pop_front_if` and `pop_back_if` to `Deque` - Made `Vec::from_array` const. +- Fixed long division being instroduced by the const-erasure in spsc +- spsc: Fix integer overflow in iterators when N > usize::MAX/2 and the queue loops. +- spsc: Fix integer overflow leading to a panic in `len` when N == usize::MAX and debug assertions are enabled. ## [v0.9.2] 2025-11-12 diff --git a/src/spsc.rs b/src/spsc.rs index 6eb6fde4d7..25b31a6f3a 100644 --- a/src/spsc.rs +++ b/src/spsc.rs @@ -129,12 +129,6 @@ pub struct QueueInner { /// A statically allocated single-producer, single-consumer queue with a capacity of `N - 1` /// elements. /// -///
-/// -/// To get better performance, use a value for `N` that is a power of 2. -/// -///
-/// /// You will likely want to use [`split`](QueueInner::split) to create a producer-consumer pair. pub type Queue = QueueInner>; @@ -182,7 +176,15 @@ impl QueueInner { #[inline] fn increment(&self, val: usize) -> usize { - (val + 1) % self.n() + // We know that self.n() <= usize::MAX + // So this can only overflow if N == usize::MAX + // and in this case the overflow will be equivalent to the modulo N operation + let val = val.wrapping_add(1); + if val >= self.n() { + val - self.n() + } else { + val + } } #[inline] @@ -202,10 +204,13 @@ impl QueueInner { let current_head = self.head.load(Ordering::Relaxed); let current_tail = self.tail.load(Ordering::Relaxed); - current_tail - .wrapping_sub(current_head) - .wrapping_add(self.n()) - % self.n() + if current_tail >= current_head { + current_tail - current_head + } else { + current_tail + .wrapping_sub(current_head) + .wrapping_add(self.n()) + } } /// Returns whether the queue is empty. @@ -626,7 +631,11 @@ impl<'a, T> Iterator for Iter<'a, T> { if self.index < self.len { let head = self.rb.head.load(Ordering::Relaxed); - let i = (head + self.index) % self.rb.n(); + let i = match head.checked_add(self.index) { + Some(i) if i >= self.rb.n() => i - self.rb.n(), + Some(i) => i, + None => head.wrapping_add(self.index).wrapping_sub(self.rb.n()), + }; self.index += 1; Some(unsafe { &*(self.rb.buffer.borrow().get_unchecked(i).get() as *const T) }) @@ -643,7 +652,11 @@ impl<'a, T> Iterator for IterMut<'a, T> { if self.index < self.len { let head = self.rb.head.load(Ordering::Relaxed); - let i = (head + self.index) % self.rb.n(); + let i = match head.checked_add(self.index) { + Some(i) if i >= self.rb.n() => i - self.rb.n(), + Some(i) => i, + None => head.wrapping_add(self.index).wrapping_sub(self.rb.n()), + }; self.index += 1; Some(unsafe { &mut *self.rb.buffer.borrow().get_unchecked(i).get().cast::() }) @@ -659,7 +672,11 @@ impl DoubleEndedIterator for Iter<'_, T> { let head = self.rb.head.load(Ordering::Relaxed); // self.len > 0, since it's larger than self.index > 0 - let i = (head + self.len - 1) % self.rb.n(); + let i = match head.checked_add(self.len - 1) { + Some(i) if i >= self.rb.n() => i - self.rb.n(), + Some(i) => i, + None => head.wrapping_add(self.len - 1).wrapping_sub(self.rb.n()), + }; self.len -= 1; Some(unsafe { &*(self.rb.buffer.borrow().get_unchecked(i).get() as *const T) }) } else { @@ -674,7 +691,11 @@ impl DoubleEndedIterator for IterMut<'_, T> { let head = self.rb.head.load(Ordering::Relaxed); // self.len > 0, since it's larger than self.index > 0 - let i = (head + self.len - 1) % self.rb.n(); + let i = match head.checked_add(self.len - 1) { + Some(i) if i >= self.rb.n() => i - self.rb.n(), + Some(i) => i, + None => head.wrapping_add(self.len - 1).wrapping_sub(self.rb.n()), + }; self.len -= 1; Some(unsafe { &mut *self.rb.buffer.borrow().get_unchecked(i).get().cast::() }) } else { @@ -882,9 +903,19 @@ impl Producer<'_, T> { #[cfg(test)] mod tests { - use std::hash::{Hash, Hasher}; + use std::{ + cell::UnsafeCell, + hash::{Hash, Hasher}, + mem::MaybeUninit, + }; use super::{Consumer, Producer, Queue}; + #[cfg(not(feature = "portable-atomic"))] + use core::sync::atomic; + #[cfg(feature = "portable-atomic")] + use portable_atomic as atomic; + + use atomic::AtomicUsize; use static_assertions::assert_not_impl_any; @@ -1076,6 +1107,28 @@ mod tests { assert_eq!(items.next(), None); } + /// Exercise the modulo `self.n()` operation in `next()` + #[test] + fn iter_modulo() { + let mut rb: Queue = Queue::new(); + + for _ in 0..2 { + rb.enqueue(0).unwrap(); + rb.dequeue().unwrap(); + } + rb.enqueue(1).unwrap(); + rb.enqueue(2).unwrap(); + rb.enqueue(3).unwrap(); + + let mut items = rb.iter(); + + // assert_eq!(items.next(), Some(&0)); + assert_eq!(items.next(), Some(&1)); + assert_eq!(items.next(), Some(&2)); + assert_eq!(items.next(), Some(&3)); + assert_eq!(items.next(), None); + } + #[test] fn iter_double_ended() { let mut rb: Queue = Queue::new(); @@ -1093,6 +1146,28 @@ mod tests { assert_eq!(items.next_back(), None); } + /// Test that the modulo in `next_back` works as expected + #[test] + fn iter_double_ended_modulo() { + let mut rb: Queue = Queue::new(); + + for _ in 0..2 { + rb.enqueue(0).unwrap(); + rb.dequeue().unwrap(); + } + rb.enqueue(0).unwrap(); + rb.enqueue(1).unwrap(); + rb.enqueue(2).unwrap(); + + let mut items = rb.iter(); + + assert_eq!(items.next(), Some(&0)); + assert_eq!(items.next_back(), Some(&2)); + assert_eq!(items.next(), Some(&1)); + assert_eq!(items.next(), None); + assert_eq!(items.next_back(), None); + } + #[test] fn iter_mut() { let mut rb: Queue = Queue::new(); @@ -1126,6 +1201,28 @@ mod tests { assert_eq!(items.next_back(), None); } + /// Test that the modulo in `next_back` works as expected + #[test] + fn iter_mut_double_ended_modulo() { + let mut rb: Queue = Queue::new(); + + for _ in 0..2 { + rb.enqueue(0).unwrap(); + rb.dequeue().unwrap(); + } + rb.enqueue(0).unwrap(); + rb.enqueue(1).unwrap(); + rb.enqueue(2).unwrap(); + + let mut items = rb.iter_mut(); + + assert_eq!(items.next(), Some(&mut 0)); + assert_eq!(items.next_back(), Some(&mut 2)); + assert_eq!(items.next(), Some(&mut 1)); + assert_eq!(items.next(), None); + assert_eq!(items.next_back(), None); + } + #[test] fn wrap_around() { let mut rb: Queue = Queue::new(); @@ -1238,4 +1335,74 @@ mod tests { }; assert_eq!(hash1, hash2); } + + // Test for some integer overflow bugs. See + // https://github.com/rust-embedded/heapless/pull/652#discussion_r3046630717 + // for more info + #[test] + #[cfg_attr(miri, ignore)] // too slow + fn test_len_overflow() { + let mut queue = Queue::<(), { usize::MAX }> { + head: AtomicUsize::new(usize::MAX), + tail: AtomicUsize::new(2), + buffer: [const { UnsafeCell::new(MaybeUninit::new(())) }; usize::MAX], + }; + queue.enqueue(()).unwrap(); + queue.enqueue(()).unwrap(); + + let collected: Vec<_> = queue.iter().collect(); + assert_eq!(&collected, &[&(); 4]); + } + + #[test] + #[cfg_attr(miri, ignore)] // too slow + fn test_usize_overflow_iter() { + let queue = Queue::<(), { usize::MAX - 1 }> { + head: AtomicUsize::new(usize::MAX - 3), + tail: AtomicUsize::new(2), + buffer: [const { UnsafeCell::new(MaybeUninit::new(())) }; usize::MAX - 1], + }; + + let collected: Vec<_> = queue.iter().collect(); + assert_eq!(&collected, &[&(); 4]); + } + + #[test] + #[cfg_attr(miri, ignore)] // too slow + fn test_usize_overflow_iter_mut() { + let mut queue = Queue::<(), { usize::MAX - 1 }> { + head: AtomicUsize::new(usize::MAX - 3), + tail: AtomicUsize::new(2), + buffer: [const { UnsafeCell::new(MaybeUninit::new(())) }; usize::MAX - 1], + }; + + let collected: Vec<_> = queue.iter_mut().collect(); + assert_eq!(&collected, &[&(); 4]); + } + + #[test] + #[cfg_attr(miri, ignore)] // too slow + fn test_usize_overflow_iter_rev() { + let queue = Queue::<(), { usize::MAX - 1 }> { + head: AtomicUsize::new(usize::MAX - 3), + tail: AtomicUsize::new(2), + buffer: [const { UnsafeCell::new(MaybeUninit::new(())) }; usize::MAX - 1], + }; + + let collected: Vec<_> = queue.iter().rev().collect(); + assert_eq!(&collected, &[&(); 4]); + } + + #[test] + #[cfg_attr(miri, ignore)] // too slow + fn test_usize_overflow_iter_mut_rev() { + let mut queue = Queue::<(), { usize::MAX - 1 }> { + head: AtomicUsize::new(usize::MAX - 3), + tail: AtomicUsize::new(2), + buffer: [const { UnsafeCell::new(MaybeUninit::new(())) }; usize::MAX - 1], + }; + + let collected: Vec<_> = queue.iter_mut().rev().collect(); + assert_eq!(&collected, &[&(); 4]); + } }