Skip to content

Commit 5f41cc7

Browse files
committed
fix count bug in OrderedLocalQueue
1 parent 30280fd commit 5f41cc7

1 file changed

Lines changed: 41 additions & 30 deletions

File tree

core/src/common/ordered_work_steal.rs

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -181,22 +181,13 @@ impl<'l, T: Debug> OrderedLocalQueue<'l, T> {
181181
}
182182

183183
/// Returns `true` if the local queue is empty.
184-
///
185-
/// When the `len` counter indicates non-empty, this method verifies by
186-
/// scanning all Workers. The counter can become stale-high when sibling
187-
/// schedulers steal items via `stealer().steal()` without decrementing
188-
/// our counter.
184+
pub fn is_local_empty(&self) -> bool {
185+
self.local_len() == 0
186+
}
187+
188+
/// Returns `true` if all the queues are empty.
189189
pub fn is_empty(&self) -> bool {
190-
if self.len() == 0 {
191-
return true;
192-
}
193-
// len might be stale due to concurrent work-stealing — verify workers
194-
for entry in self.queue {
195-
if !entry.value().is_empty() {
196-
return false;
197-
}
198-
}
199-
true
190+
self.len() == 0
200191
}
201192

202193
/// Returns `true` if the local queue is full.
@@ -208,19 +199,19 @@ impl<'l, T: Debug> OrderedLocalQueue<'l, T> {
208199
///
209200
/// let queue = OrderedWorkStealQueue::new(1, 2);
210201
/// let local = queue.local_queue();
211-
/// assert!(local.is_empty());
202+
/// assert!(local.is_local_empty());
212203
/// for i in 0..2 {
213204
/// local.push_with_priority(i, i);
214205
/// }
215-
/// assert!(local.is_full());
206+
/// assert!(local.is_local_full());
216207
/// assert_eq!(local.pop(), Some(0));
217-
/// assert_eq!(local.len(), 1);
208+
/// assert_eq!(local.local_len(), 1);
218209
/// assert_eq!(local.pop(), Some(1));
219210
/// assert_eq!(local.pop(), None);
220-
/// assert!(local.is_empty());
211+
/// assert!(local.is_local_empty());
221212
/// ```
222-
pub fn is_full(&self) -> bool {
223-
self.len() >= self.shared.local_capacity
213+
pub fn is_local_full(&self) -> bool {
214+
self.local_len() >= self.shared.local_capacity
224215
}
225216

226217
fn max_steal(&self) -> usize {
@@ -229,11 +220,11 @@ impl<'l, T: Debug> OrderedLocalQueue<'l, T> {
229220
.local_capacity
230221
.saturating_add(1)
231222
.saturating_div(2)
232-
.saturating_sub(self.len())
223+
.saturating_sub(self.local_len())
233224
}
234225

235226
fn can_steal(&self) -> bool {
236-
self.len()
227+
self.local_len()
237228
< self
238229
.shared
239230
.local_capacity
@@ -242,10 +233,20 @@ impl<'l, T: Debug> OrderedLocalQueue<'l, T> {
242233
}
243234

244235
/// Returns the number of elements in the queue.
245-
pub fn len(&self) -> usize {
236+
pub fn local_len(&self) -> usize {
246237
self.len.load(Ordering::Acquire)
247238
}
248239

240+
/// Returns the number of elements in the all queues.
241+
pub fn len(&self) -> usize {
242+
let mut full_len = self.local_len() + self.shared.len();
243+
for entry in self.queue {
244+
let worker = entry.value();
245+
full_len += worker.capacity() - worker.spare_capacity();
246+
}
247+
full_len
248+
}
249+
249250
fn try_lock(&self) -> bool {
250251
self.stealing
251252
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
@@ -275,7 +276,7 @@ impl<'l, T: Debug> OrderedLocalQueue<'l, T> {
275276
/// assert_eq!(local.pop(), None);
276277
/// ```
277278
pub fn push_with_priority(&self, priority: c_longlong, item: T) {
278-
if self.is_full() {
279+
if self.is_local_full() {
279280
self.push_to_global(priority, item);
280281
return;
281282
}
@@ -289,20 +290,23 @@ impl<'l, T: Debug> OrderedLocalQueue<'l, T> {
289290
} else {
290291
//add count
291292
self.len
292-
.store(self.len().saturating_add(1), Ordering::Release);
293+
.store(self.local_len().saturating_add(1), Ordering::Release);
293294
}
294295
}
295296

296297
fn push_to_global(&self, priority: c_longlong, item: T) {
297298
//把本地队列的一半放到全局队列
298-
let count = self.len() / 2;
299+
let count = self.local_len() / 2;
299300
for _ in 0..count {
300301
for entry in self.queue.iter().rev() {
301302
if let Some(item) = entry.value().pop() {
302303
self.shared.push_with_priority(*entry.key(), item);
303304
}
304305
}
305306
}
307+
// refresh count
308+
self.len
309+
.store(self.len().saturating_sub(count), Ordering::Release);
306310
//直接放到全局队列
307311
self.shared.push_with_priority(priority, item);
308312
}
@@ -346,12 +350,12 @@ impl<'l, T: Debug> OrderedLocalQueue<'l, T> {
346350
/// for i in 2..6 {
347351
/// local0.push_with_priority(i, i);
348352
/// }
349-
/// assert_eq!(local0.len(), 4);
353+
/// assert_eq!(local0.local_len(), 4);
350354
/// let local1 = queue.local_queue();
351355
/// for i in 0..2 {
352356
/// local1.push_with_priority(i, i);
353357
/// }
354-
/// assert_eq!(local1.len(), 2);
358+
/// assert_eq!(local1.local_len(), 2);
355359
/// for i in 0..6 {
356360
/// assert_eq!(local1.pop(), Some(i));
357361
/// }
@@ -411,6 +415,13 @@ impl<'l, T: Debug> OrderedLocalQueue<'l, T> {
411415
})
412416
.is_ok()
413417
{
418+
// refresh local len
419+
self.len.store(
420+
self.local_len().saturating_add(
421+
into_queue.capacity() - into_queue.spare_capacity(),
422+
),
423+
Ordering::Release,
424+
);
414425
self.release_lock();
415426
return self.pop_local();
416427
}
@@ -429,7 +440,7 @@ impl<'l, T: Debug> OrderedLocalQueue<'l, T> {
429440
if let Some(val) = entry.value().pop() {
430441
// Decrement the count.
431442
self.len
432-
.store(self.len().saturating_sub(1), Ordering::Release);
443+
.store(self.local_len().saturating_sub(1), Ordering::Release);
433444
return Some(val);
434445
}
435446
}

0 commit comments

Comments
 (0)