Skip to content

Commit b6712cf

Browse files
committed
faster true count
1 parent f4339ca commit b6712cf

2 files changed

Lines changed: 270 additions & 5 deletions

File tree

vortex-buffer/benches/vortex_bitbuffer.rs

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,24 @@ impl FromIterator<bool> for Arrow<BooleanBuffer> {
2525
}
2626

2727
const INPUT_SIZE: &[usize] = &[128, 1024, 2048, 16_384, 65_536];
28+
const TRUE_COUNT_INPUT_SIZE: &[usize] = &[128, 1024, 16_384, 65_536, 1_048_576];
29+
const TRUE_COUNT_SLICE_CASES: &[(usize, usize)] = &[
30+
(128, 1),
31+
(128, 7),
32+
(1_024, 1),
33+
(1_024, 7),
34+
(16_384, 1),
35+
(16_384, 7),
36+
(65_536, 1),
37+
(65_536, 7),
38+
(1_048_576, 1),
39+
(1_048_576, 7),
40+
];
41+
42+
#[inline]
43+
fn true_count_pattern(i: usize) -> bool {
44+
(i.is_multiple_of(3)) ^ (i.is_multiple_of(11))
45+
}
2846

2947
#[cfg(not(codspeed))]
3048
#[divan::bench(args = INPUT_SIZE)]
@@ -158,22 +176,46 @@ fn slice_arrow_buffer(bencher: Bencher, length: usize) {
158176
});
159177
}
160178

161-
#[divan::bench(args = INPUT_SIZE)]
179+
#[divan::bench(args = TRUE_COUNT_INPUT_SIZE)]
162180
fn true_count_vortex_buffer(bencher: Bencher, length: usize) {
163-
let buffer = BitBuffer::from_iter((0..length).map(|i| i % 2 == 0));
181+
let buffer = BitBuffer::from_iter((0..length).map(true_count_pattern));
164182
bencher
165183
.with_inputs(|| &buffer)
166184
.bench_refs(|buffer| buffer.true_count())
167185
}
168186

169-
#[divan::bench(args = INPUT_SIZE)]
187+
#[divan::bench(args = TRUE_COUNT_INPUT_SIZE)]
170188
fn true_count_arrow_buffer(bencher: Bencher, length: usize) {
171-
let buffer = Arrow(BooleanBuffer::from_iter((0..length).map(|i| i % 2 == 0)));
189+
let buffer = Arrow(BooleanBuffer::from_iter(
190+
(0..length).map(true_count_pattern),
191+
));
172192
bencher
173193
.with_inputs(|| &buffer)
174194
.bench_refs(|buffer| buffer.0.count_set_bits());
175195
}
176196

197+
#[divan::bench(args = TRUE_COUNT_SLICE_CASES)]
198+
fn true_count_vortex_buffer_sliced(bencher: Bencher, (length, offset): (usize, usize)) {
199+
let buffer = BitBuffer::from_iter((0..length + offset).map(true_count_pattern));
200+
let sliced = buffer.slice(offset..offset + length);
201+
202+
bencher
203+
.with_inputs(|| &sliced)
204+
.bench_refs(|buffer| buffer.true_count())
205+
}
206+
207+
#[divan::bench(args = TRUE_COUNT_SLICE_CASES)]
208+
fn true_count_arrow_buffer_sliced(bencher: Bencher, (length, offset): (usize, usize)) {
209+
let buffer = Arrow(BooleanBuffer::from_iter(
210+
(0..length + offset).map(true_count_pattern),
211+
));
212+
let sliced = Arrow(buffer.0.slice(offset, length));
213+
214+
bencher
215+
.with_inputs(|| &sliced)
216+
.bench_refs(|buffer| buffer.0.count_set_bits());
217+
}
218+
177219
#[divan::bench(args = INPUT_SIZE)]
178220
fn bitwise_and_vortex_buffer(bencher: Bencher, length: usize) {
179221
let a = BitBuffer::from_iter((0..length).map(|i| i % 2 == 0));

vortex-buffer/src/bit/buf.rs

Lines changed: 224 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
5+
use std::arch::is_x86_feature_detected;
46
use std::fmt::Display;
57
use std::fmt::Formatter;
68
use std::fmt::Result as FmtResult;
@@ -316,7 +318,7 @@ impl BitBuffer {
316318

317319
/// Get the number of set bits in the buffer.
318320
pub fn true_count(&self) -> usize {
319-
self.unaligned_chunks().count_ones()
321+
true_count_impl(self.buffer.as_slice(), self.offset, self.len)
320322
}
321323

322324
/// Get the number of unset bits in the buffer.
@@ -356,6 +358,198 @@ impl BitBuffer {
356358
}
357359
}
358360

361+
#[inline]
362+
fn true_count_impl(bytes: &[u8], offset: usize, len: usize) -> usize {
363+
let Some((head, middle, tail)) = byte_aligned_region(bytes, offset, len) else {
364+
return 0;
365+
};
366+
367+
let mut count = head.map_or(0, |v| v.count_ones() as usize);
368+
369+
if !middle.is_empty() {
370+
count += count_aligned_bytes(middle);
371+
}
372+
373+
count + tail.map_or(0, |v| v.count_ones() as usize)
374+
}
375+
376+
#[inline]
377+
fn byte_aligned_region(
378+
bytes: &[u8],
379+
offset: usize,
380+
len: usize,
381+
) -> Option<(Option<u8>, &[u8], Option<u8>)> {
382+
if len == 0 {
383+
return None;
384+
}
385+
386+
let start_byte = offset / 8;
387+
let start_bit = offset % 8;
388+
let end_bit = offset + len;
389+
let end_byte = end_bit / 8;
390+
let head = (start_bit != 0).then(|| {
391+
let head_len = (8 - start_bit).min(len);
392+
mask_partial_byte(bytes[start_byte], start_bit, head_len)
393+
});
394+
395+
let middle_start = start_byte + usize::from(start_bit != 0);
396+
let middle_end = end_byte;
397+
let middle = if middle_start < middle_end {
398+
&bytes[middle_start..middle_end]
399+
} else {
400+
&[]
401+
};
402+
403+
let consumed = if start_bit != 0 {
404+
(8 - start_bit).min(len)
405+
} else {
406+
0
407+
} + middle.len() * 8;
408+
let tail_len = len - consumed;
409+
let tail = (tail_len != 0).then(|| mask_partial_byte(bytes[middle_end], 0, tail_len));
410+
411+
Some((head, middle, tail))
412+
}
413+
414+
#[inline]
415+
fn mask_partial_byte(byte: u8, bit_offset: usize, bit_len: usize) -> u8 {
416+
debug_assert!(bit_offset < 8);
417+
debug_assert!(bit_len <= 8 - bit_offset);
418+
419+
let shifted = byte >> bit_offset;
420+
let mask = if bit_len == 8 {
421+
u8::MAX
422+
} else {
423+
(1u8 << bit_len) - 1
424+
};
425+
426+
shifted & mask
427+
}
428+
429+
#[inline]
430+
fn count_aligned_bytes(bytes: &[u8]) -> usize {
431+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
432+
{
433+
if bytes.len() >= 64
434+
&& is_x86_feature_detected!("avx512f")
435+
&& is_x86_feature_detected!("avx512vpopcntdq")
436+
{
437+
// SAFETY: Runtime detection guarantees the required target features.
438+
return unsafe { count_aligned_bytes_avx512(bytes) };
439+
}
440+
441+
if bytes.len() >= 32 && is_x86_feature_detected!("avx2") {
442+
// SAFETY: Runtime detection guarantees the required target features.
443+
return unsafe { count_aligned_bytes_avx2(bytes) };
444+
}
445+
}
446+
447+
count_aligned_bytes_scalar(bytes)
448+
}
449+
450+
#[inline]
451+
fn count_aligned_bytes_scalar(bytes: &[u8]) -> usize {
452+
let (words, tail) = bytes.as_chunks::<8>();
453+
let mut count = words
454+
.iter()
455+
.map(|word| u64::from_le_bytes(*word).count_ones() as usize)
456+
.sum::<usize>();
457+
458+
count += tail
459+
.iter()
460+
.map(|byte| byte.count_ones() as usize)
461+
.sum::<usize>();
462+
463+
count
464+
}
465+
466+
#[cfg(target_arch = "x86_64")]
467+
#[target_feature(enable = "avx2")]
468+
unsafe fn count_aligned_bytes_avx2(bytes: &[u8]) -> usize {
469+
use std::arch::x86_64::__m256i;
470+
use std::arch::x86_64::_mm256_add_epi8;
471+
use std::arch::x86_64::_mm256_add_epi64;
472+
use std::arch::x86_64::_mm256_and_si256;
473+
use std::arch::x86_64::_mm256_loadu_si256;
474+
use std::arch::x86_64::_mm256_sad_epu8;
475+
use std::arch::x86_64::_mm256_set1_epi8;
476+
use std::arch::x86_64::_mm256_setr_epi8;
477+
use std::arch::x86_64::_mm256_setzero_si256;
478+
use std::arch::x86_64::_mm256_shuffle_epi8;
479+
use std::arch::x86_64::_mm256_srli_epi16;
480+
use std::arch::x86_64::_mm256_storeu_si256;
481+
482+
#[inline]
483+
unsafe fn byte_popcount(chunk: __m256i, mask: __m256i, lookup: __m256i) -> __m256i {
484+
let lo = unsafe { _mm256_and_si256(chunk, mask) };
485+
let hi = unsafe { _mm256_and_si256(_mm256_srli_epi16(chunk, 4), mask) };
486+
unsafe {
487+
_mm256_add_epi8(
488+
_mm256_shuffle_epi8(lookup, lo),
489+
_mm256_shuffle_epi8(lookup, hi),
490+
)
491+
}
492+
}
493+
494+
let lookup = _mm256_setr_epi8(
495+
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3,
496+
3, 4,
497+
);
498+
let mask = _mm256_set1_epi8(0x0f);
499+
let zero = _mm256_setzero_si256();
500+
let mut accum = _mm256_setzero_si256();
501+
let mut index = 0;
502+
503+
while index + 128 <= bytes.len() {
504+
for lane in 0..4 {
505+
let ptr = unsafe { bytes.as_ptr().add(index + lane * 32) }.cast::<__m256i>();
506+
let chunk = unsafe { _mm256_loadu_si256(ptr) };
507+
let counts = unsafe { byte_popcount(chunk, mask, lookup) };
508+
accum = _mm256_add_epi64(accum, _mm256_sad_epu8(counts, zero));
509+
}
510+
index += 128;
511+
}
512+
513+
while index + 32 <= bytes.len() {
514+
let ptr = unsafe { bytes.as_ptr().add(index) }.cast::<__m256i>();
515+
let chunk = unsafe { _mm256_loadu_si256(ptr) };
516+
let counts = unsafe { byte_popcount(chunk, mask, lookup) };
517+
accum = _mm256_add_epi64(accum, _mm256_sad_epu8(counts, zero));
518+
index += 32;
519+
}
520+
521+
let mut lanes = [0u64; 4];
522+
unsafe { _mm256_storeu_si256(lanes.as_mut_ptr().cast::<__m256i>(), accum) };
523+
524+
lanes.iter().sum::<u64>() as usize + count_aligned_bytes_scalar(&bytes[index..])
525+
}
526+
527+
#[cfg(target_arch = "x86_64")]
528+
#[target_feature(enable = "avx512f,avx512vpopcntdq")]
529+
unsafe fn count_aligned_bytes_avx512(bytes: &[u8]) -> usize {
530+
use std::arch::x86_64::__m512i;
531+
use std::arch::x86_64::_mm512_add_epi64;
532+
use std::arch::x86_64::_mm512_loadu_si512;
533+
use std::arch::x86_64::_mm512_popcnt_epi64;
534+
use std::arch::x86_64::_mm512_setzero_si512;
535+
use std::arch::x86_64::_mm512_storeu_si512;
536+
537+
let mut accum = _mm512_setzero_si512();
538+
let mut index = 0;
539+
540+
while index + 64 <= bytes.len() {
541+
let ptr = unsafe { bytes.as_ptr().add(index) }.cast::<__m512i>();
542+
let chunk = unsafe { _mm512_loadu_si512(ptr) };
543+
accum = _mm512_add_epi64(accum, _mm512_popcnt_epi64(chunk));
544+
index += 64;
545+
}
546+
547+
let mut lanes = [0u64; 8];
548+
unsafe { _mm512_storeu_si512(lanes.as_mut_ptr().cast::<__m512i>(), accum) };
549+
550+
lanes.iter().sum::<u64>() as usize + count_aligned_bytes_scalar(&bytes[index..])
551+
}
552+
359553
// Conversions
360554

361555
impl BitBuffer {
@@ -784,4 +978,33 @@ mod tests {
784978
assert_eq!(mapped.value(i), expected, "Mismatch at index {}", i);
785979
}
786980
}
981+
982+
#[rstest]
983+
fn test_true_count_matches_iteration_for_slices(
984+
#[values(
985+
0usize, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
986+
23, 24, 25, 26, 27, 28, 29, 30
987+
)]
988+
offset: usize,
989+
#[values(
990+
0usize, 1, 2, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 255, 256, 257, 513
991+
)]
992+
slice_len: usize,
993+
) {
994+
let len = 513;
995+
let buf = BitBuffer::collect_bool(len + 31, |i| (i % 3 == 0) ^ (i % 11 == 0));
996+
997+
if offset + slice_len > buf.len() {
998+
return;
999+
}
1000+
1001+
let sliced = buf.slice(offset..offset + slice_len);
1002+
let expected = sliced.iter().filter(|bit| *bit).count();
1003+
1004+
assert_eq!(
1005+
sliced.true_count(),
1006+
expected,
1007+
"offset={offset} len={slice_len}"
1008+
);
1009+
}
7871010
}

0 commit comments

Comments
 (0)