Skip to content

Commit cfbbb6d

Browse files
committed
add size limit option - unlimited size performs the same, and backpressure can be modeled with size limit
1 parent 1e848ff commit cfbbb6d

8 files changed

Lines changed: 186 additions & 24 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

benchmarks/benches/channel.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,54 @@ fn mpsc_or_queues(criterion: &mut Criterion) {
120120
elapsed
121121
});
122122
});
123+
124+
group.bench_function("spillway_batch_limited", |bencher| {
125+
bencher
126+
.to_async(runtime(threads))
127+
.iter_custom(async |size| {
128+
let (send, mut receive) =
129+
spillway::channel_with_capacity_and_concurrency(65536, threads);
130+
let receiver = tokio::spawn(async move {
131+
let mut i = 0;
132+
while let Some(v) = receive.next().await {
133+
i += v;
134+
}
135+
i
136+
});
137+
138+
for _ in 0..threads {
139+
let send = send.clone();
140+
tokio::spawn(async move {
141+
let per_thread = (size as usize / threads).max(1);
142+
let batches = per_thread.div_ceil(32);
143+
for _ in 0..batches {
144+
let mut pending = 0_usize..32;
145+
loop {
146+
match send.send_many(pending) {
147+
Ok(()) => break,
148+
Err(spillway::Error::Full(returned)) => {
149+
pending = returned;
150+
tokio::task::yield_now().await;
151+
}
152+
Err(spillway::Error::Closed(_)) => return,
153+
}
154+
}
155+
}
156+
});
157+
}
158+
drop(send);
159+
160+
let start = Instant::now();
161+
let n = receiver.await;
162+
let elapsed = start.elapsed();
163+
log::info!("ok {n:?}");
164+
let per_thread = (size as usize / threads).max(1);
165+
let batches = per_thread.div_ceil(32);
166+
let expected: usize = (0..32).sum::<usize>() * batches * threads;
167+
assert_eq!(n.expect("must join successfully"), expected);
168+
elapsed
169+
});
170+
});
123171
}
124172

125173
fn runtime(threads: usize) -> tokio::runtime::Runtime {

spillway/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ categories.workspace = true
1313
futures = { workspace = true }
1414
log = { workspace = true }
1515
rand = { workspace = true }
16+
thiserror = { workspace = true }
1617

1718
[dev-dependencies]
1819
tokio-test = { workspace = true }

spillway/src/error.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
/// Errors returned by Spillway senders.
2+
///
3+
/// The unsent value(s) are returned in the error variant so the caller can
4+
/// reuse or drop them as appropriate.
///
/// For single-value sends `T` is the value itself; for batch sends it is the
/// iterator over the batch that was not enqueued.
5+
#[derive(Debug, thiserror::Error)]
6+
pub enum Error<T> {
7+
/// The value(s) did not fit within the channel's soft capacity. Nothing was enqueued.
///
/// A single value is rejected when the channel is already at or above its
/// capacity. A batch is rejected whenever adding the whole batch would exceed
/// the capacity, even if the channel still has room for part of it.
8+
#[error("spillway channel is full")]
9+
Full(T),
10+
/// The Receiver has been dropped. The channel will never accept more values.
11+
#[error("spillway channel is closed")]
12+
Closed(T),
13+
}

spillway/src/lib.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
#![deny(missing_docs)]
22
#![doc = include_str!("../README.md")]
33

4+
mod error;
45
mod receiver;
56
mod sender;
67
mod shared;
78

89
use std::sync::Arc;
910

11+
pub use error::Error;
1012
pub use receiver::Receiver;
1113
pub use sender::Sender;
1214

13-
/// Get a new spillway channel with a default concurrency level.
15+
/// Get a new spillway channel with a default concurrency level and no capacity limit.
1416
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
1517
// const PARALLELISM: std::sync::LazyLock<usize> = std::sync::LazyLock::new(|| {
1618
// std::thread::available_parallelism()
@@ -21,14 +23,30 @@ pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
2123
channel_with_concurrency(8)
2224
}
2325

24-
/// Get a new spillway channel with the given concurrency level.
26+
/// Get a new spillway channel with the given concurrency level and no capacity limit.
2527
///
2628
/// Use this when you need lots of parallelism, or when you know how many Senders
2729
/// you will have. Higher numbers reduce contention, but increase the cost of
2830
/// parking the Receiver when idle. Thread count is a good starting point for
2931
/// concurrency.
3032
pub fn channel_with_concurrency<T>(concurrency: usize) -> (Sender<T>, Receiver<T>) {
31-
let shared = Arc::new(shared::Shared::new(concurrency));
33+
channel_with_capacity_and_concurrency(u64::MAX, concurrency)
34+
}
35+
36+
/// Get a new spillway channel with a soft capacity limit and the given concurrency level.
37+
///
38+
/// `capacity` is an upper bound on the number of in-flight values. Sends are
39+
/// rejected with [`Error::Full`] when the channel is at or above this limit.
40+
///
41+
/// Mind your batch sizes when using a capacity limit. If you have a capacity of 10 and you
42+
/// send 11 values in a batch, the entire batch will be rejected.
43+
///
44+
/// Pass `u64::MAX` to disable the limit (this matches [`channel_with_concurrency`]).
45+
pub fn channel_with_capacity_and_concurrency<T>(
46+
capacity: u64,
47+
concurrency: usize,
48+
) -> (Sender<T>, Receiver<T>) {
49+
let shared = Arc::new(shared::Shared::new(concurrency, capacity));
3250
let sender = Sender::new(shared.clone());
3351
let receiver = Receiver::new(shared);
3452

spillway/src/receiver.rs

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use std::{collections::VecDeque, sync::Arc};
1+
use std::{
2+
collections::VecDeque,
3+
sync::{atomic::AtomicU64, Arc},
4+
};
25

36
use crate::shared::Shared;
47

@@ -43,7 +46,10 @@ impl<T> Receiver<T> {
4346
/// * `Poll::Ready(None)` when all senders have been dropped and the Receiver is caught up. The Receiver will never receive more messages and you should drop it.
4447
pub fn poll_next(&mut self, context: &mut std::task::Context) -> std::task::Poll<Option<T>> {
4548
match self.buffer.pop_front() {
46-
Some(next) => std::task::Poll::Ready(Some(next)),
49+
Some(next) => {
50+
self.decrement_size(1);
51+
std::task::Poll::Ready(Some(next))
52+
}
4753
None => {
4854
let dirty_index = match self.shared.race_find_dirty(self.cursor) {
4955
Some(dirty_index) => {
@@ -89,11 +95,20 @@ impl<T> Receiver<T> {
8995
.buffer
9096
.pop_front()
9197
.expect("chutes are only dirty when they have contents");
98+
self.decrement_size(1);
9299
std::task::Poll::Ready(Some(next))
93100
}
94101
}
95102
}
96103

104+
fn decrement_size(&self, count: usize) {
105+
if self.shared.capacity != u64::MAX {
106+
self.shared
107+
.channel_size
108+
.fetch_sub(count as u64, std::sync::atomic::Ordering::Relaxed);
109+
}
110+
}
111+
97112
/// The next value for the Receiver.
98113
///
99114
/// * Some(T) is the next value.
@@ -116,11 +131,29 @@ impl<T> Receiver<T> {
116131
// we got one, but let's see if we can get more while we're here.
117132
// for convenience, we'll put the item back and drain the whole batch.
118133
self.buffer.push_front(next);
134+
// poll_next decremented channel_size by 1; put that one back. BatchDrain will
135+
// decrement once on Drop for however many items the caller consumes.
136+
if self.shared.capacity != u64::MAX {
137+
self.shared
138+
.channel_size
139+
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
140+
}
141+
let initial_len = self.buffer.len();
142+
let channel_size = if self.shared.capacity == u64::MAX {
143+
None
144+
} else {
145+
// SAFETY: same lifetime widening as `buffer` below — this borrow is into
146+
// `self.shared.channel_size`, which lives at least as long as 'a (it lives
147+
// as long as the Arc held by `self`).
148+
Some(unsafe { &*(&self.shared.channel_size as *const AtomicU64) })
149+
};
119150
// SAFETY: we have exclusive access to self for 'a. self itself is not referenced out from the fnmut, but the buffer is, which is
120151
// causing some borrow checker consternation. But since the buffer mutable borrow cannot outlive 'a, and &mut self can't
121152
// outlive 'a either, the borrow of buffer should be sound for 'a.
122153
std::task::Poll::Ready(Some(BatchDrain {
123154
buffer: unsafe { &mut *(&mut self.buffer as *mut _) },
155+
channel_size,
156+
initial_len,
124157
}))
125158
}
126159
std::task::Poll::Ready(None) => std::task::Poll::Ready(None),
@@ -133,6 +166,8 @@ impl<T> Receiver<T> {
133166

134167
/// Batch handle that borrows the Receiver's local buffer for `'a`.
///
/// When the channel has a capacity limit, dropping the handle subtracts the
/// number of items removed from `buffer` from the shared in-flight counter.
struct BatchDrain<'a, T> {
135168
// The Receiver's local buffer that this batch drains from.
buffer: &'a mut VecDeque<T>,
169+
// Shared in-flight counter; `None` when the channel has no capacity limit.
channel_size: Option<&'a AtomicU64>,
170+
// Length of `buffer` when this handle was created; compared against the
// remaining length on drop to work out how many items were consumed.
initial_len: usize,
136171
}
137172
impl<T> Iterator for BatchDrain<'_, T> {
138173
type Item = T;
@@ -153,6 +188,17 @@ impl<T> ExactSizeIterator for BatchDrain<'_, T> {
153188
}
154189
}
155190

191+
impl<T> Drop for BatchDrain<'_, T> {
192+
fn drop(&mut self) {
193+
if let Some(channel_size) = self.channel_size {
194+
let consumed = self.initial_len - self.buffer.len();
195+
if consumed != 0 {
196+
channel_size.fetch_sub(consumed as u64, std::sync::atomic::Ordering::Relaxed);
197+
}
198+
}
199+
}
200+
}
201+
156202
#[cfg(test)]
157203
mod test {
158204
use std::task::{Context, Poll, Waker};

spillway/src/sender.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::sync::Arc;
22

3-
use crate::shared::Shared;
3+
use crate::{shared::Shared, Error};
44

55
/// The sending half of a Spillway channel.
66
///
@@ -56,7 +56,10 @@ impl<T> Sender<T> {
5656
/// However, you might receive 1, 4, 5, 2, 3, 6 or any other interleaving. But
5757
/// 1 will always appear before 2, and 2 before 3; and 4 will always appear before 5,
5858
/// and 5 before 6.
59-
pub fn send(&self, value: T) -> Result<(), T> {
59+
///
60+
/// Returns [`Error::Full`] if the channel has reached its capacity limit, or
61+
/// [`Error::Closed`] if the Receiver has been dropped.
62+
pub fn send(&self, value: T) -> Result<(), Error<T>> {
6063
self.shared.send(self.chute, value)
6164
}
6265

@@ -77,7 +80,15 @@ impl<T> Sender<T> {
7780
/// | 1, 2, 3, 4, 5, 6 |
7881
/// | 4, 5, 1, 2, 3, 6 |
7982
/// | 4, 5, 6, 1, 2, 3 |
80-
pub fn send_many<I: IntoIterator<Item = T>>(&self, values: I) -> Result<(), I> {
83+
///
84+
/// Returns [`Error::Full`] if the batch would push the channel past its capacity
85+
/// limit (the entire batch is rejected; partial enqueues never happen), or
86+
/// [`Error::Closed`] if the Receiver has been dropped.
87+
pub fn send_many<I>(&self, values: I) -> Result<(), Error<I::IntoIter>>
88+
where
89+
I: IntoIterator<Item = T>,
90+
I::IntoIter: ExactSizeIterator,
91+
{
8192
self.shared.send_many(self.chute, values)
8293
}
8394
}

spillway/src/shared.rs

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
use std::{
22
collections::VecDeque,
33
sync::{
4-
atomic::{AtomicBool, AtomicUsize},
4+
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
55
Mutex,
66
},
77
};
88

9+
use crate::Error;
10+
911
pub struct Chute<T> {
1012
queue: Mutex<VecDeque<T>>,
1113
clean: AtomicBool,
@@ -60,6 +62,8 @@ pub struct Shared<T> {
6062
pub(crate) waker: futures::task::AtomicWaker,
6163
pub(crate) senders: AtomicUsize,
6264
pub(crate) chute_clock: AtomicUsize,
65+
pub(crate) channel_size: AtomicU64,
66+
pub(crate) capacity: u64,
6367
dead: AtomicBool,
6468
}
6569

@@ -70,57 +74,77 @@ impl<T> std::fmt::Debug for Shared<T> {
7074
.field("waker", &self.waker)
7175
.field("senders", &self.senders)
7276
.field("chute_clock", &self.chute_clock)
77+
.field("channel_size", &self.channel_size)
78+
.field("capacity", &self.capacity)
7379
.field("dead", &self.dead)
7480
.finish()
7581
}
7682
}
7783

7884
impl<T> Shared<T> {
79-
pub fn new(concurrency: usize) -> Self {
85+
pub fn new(concurrency: usize, capacity: u64) -> Self {
8086
Self {
8187
chutes: (0..concurrency)
8288
.map(|_| Default::default())
8389
.collect::<Vec<_>>(),
8490
waker: futures::task::AtomicWaker::new(),
8591
senders: AtomicUsize::new(0),
8692
chute_clock: AtomicUsize::new(0),
93+
channel_size: AtomicU64::new(0),
94+
capacity,
8795
dead: AtomicBool::new(false),
8896
}
8997
}
9098

9199
pub fn add_sender(&self) {
92-
self.senders
93-
.fetch_add(1, std::sync::atomic::Ordering::Release);
100+
self.senders.fetch_add(1, Ordering::Release);
94101
}
95102

96103
pub fn drop_sender(&self) -> usize {
97-
self.senders
98-
.fetch_sub(1, std::sync::atomic::Ordering::AcqRel)
104+
self.senders.fetch_sub(1, Ordering::AcqRel)
99105
}
100106

101107
pub fn choose_chute(&self) -> usize {
102-
self.chute_clock
103-
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
104-
% self.chutes.len()
108+
self.chute_clock.fetch_add(1, Ordering::Relaxed) % self.chutes.len()
105109
}
106110

107111
pub fn wake(&self) {
108112
self.waker.wake();
109113
}
110114

111-
pub fn send(&self, chute: usize, value: T) -> Result<(), T> {
112-
if self.dead.load(std::sync::atomic::Ordering::Relaxed) {
113-
return Err(value);
115+
pub fn send(&self, chute: usize, value: T) -> Result<(), Error<T>> {
116+
if self.dead.load(Ordering::Relaxed) {
117+
return Err(Error::Closed(value));
118+
}
119+
if self.capacity != u64::MAX {
120+
let prev = self.channel_size.fetch_add(1, Ordering::Relaxed);
121+
if self.capacity < prev + 1 {
122+
self.channel_size.fetch_sub(1, Ordering::Relaxed);
123+
return Err(Error::Full(value));
124+
}
114125
}
115126
self.send_many_infallible(chute, [value]);
116127
Ok(())
117128
}
118129

119-
pub fn send_many<I: IntoIterator<Item = T>>(&self, chute: usize, values: I) -> Result<(), I> {
120-
if self.dead.load(std::sync::atomic::Ordering::Relaxed) {
121-
return Err(values);
130+
pub fn send_many<I>(&self, chute: usize, values: I) -> Result<(), Error<I::IntoIter>>
131+
where
132+
I: IntoIterator<Item = T>,
133+
I::IntoIter: ExactSizeIterator,
134+
{
135+
let iter = values.into_iter();
136+
if self.dead.load(Ordering::Relaxed) {
137+
return Err(Error::Closed(iter));
138+
}
139+
if self.capacity != u64::MAX {
140+
let count = iter.len() as u64;
141+
let prev = self.channel_size.fetch_add(count, Ordering::Relaxed);
142+
if self.capacity < prev + count {
143+
self.channel_size.fetch_sub(count, Ordering::Relaxed);
144+
return Err(Error::Full(iter));
145+
}
122146
}
123-
self.send_many_infallible(chute, values);
147+
self.send_many_infallible(chute, iter);
124148
Ok(())
125149
}
126150

0 commit comments

Comments
 (0)