Skip to content

Commit 0417cce

Browse files
committed
Faster true count using AVX intrinsics
Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent 0b981a8 commit 0417cce

4 files changed

Lines changed: 240 additions & 3 deletions

File tree

vortex-buffer/benches/vortex_bitbuffer.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ impl FromIterator<bool> for Arrow<BooleanBuffer> {
2626

2727
const INPUT_SIZE: &[usize] = &[128, 1024, 2048, 16_384, 65_536];
2828

29+
/// Benchmark bit pattern: bit `i` is set when `i` is divisible by exactly one
/// of 3 and 11.
#[inline]
fn true_count_pattern(i: usize) -> bool {
    let divisible_by_three = i % 3 == 0;
    let divisible_by_eleven = i % 11 == 0;
    // Exclusive-or of two booleans is the same as asking whether they differ.
    divisible_by_three != divisible_by_eleven
}
33+
2934
#[cfg(not(codspeed))]
3035
#[divan::bench(args = INPUT_SIZE)]
3136
fn from_iter_arrow(n: usize) {
@@ -160,15 +165,17 @@ fn slice_arrow_buffer(bencher: Bencher, length: usize) {
160165

161166
#[divan::bench(args = INPUT_SIZE)]
162167
fn true_count_vortex_buffer(bencher: Bencher, length: usize) {
163-
let buffer = BitBuffer::from_iter((0..length).map(|i| i % 2 == 0));
168+
let buffer = BitBuffer::from_iter((0..length).map(true_count_pattern));
164169
bencher
165170
.with_inputs(|| &buffer)
166171
.bench_refs(|buffer| buffer.true_count())
167172
}
168173

169174
#[divan::bench(args = INPUT_SIZE)]
170175
fn true_count_arrow_buffer(bencher: Bencher, length: usize) {
171-
let buffer = Arrow(BooleanBuffer::from_iter((0..length).map(|i| i % 2 == 0)));
176+
let buffer = Arrow(BooleanBuffer::from_iter(
177+
(0..length).map(true_count_pattern),
178+
));
172179
bencher
173180
.with_inputs(|| &buffer)
174181
.bench_refs(|buffer| buffer.0.count_set_bits());

vortex-buffer/src/bit/buf.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use crate::bit::BitIndexIterator;
2121
use crate::bit::BitIterator;
2222
use crate::bit::BitSliceIterator;
2323
use crate::bit::UnalignedBitChunk;
24+
use crate::bit::count_ones::count_ones;
2425
use crate::bit::get_bit_unchecked;
2526
use crate::bit::ops::bitwise_binary_op;
2627
use crate::bit::ops::bitwise_unary_op;
@@ -316,7 +317,7 @@ impl BitBuffer {
316317

317318
/// Get the number of set bits in the buffer.
318319
pub fn true_count(&self) -> usize {
319-
self.unaligned_chunks().count_ones()
320+
count_ones(self.buffer.as_slice(), self.offset, self.len)
320321
}
321322

322323
/// Get the number of unset bits in the buffer.
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#[cfg(target_arch = "x86_64")]
5+
use vortex_error::VortexExpect;
6+
7+
#[inline]
8+
pub fn count_ones(bytes: &[u8], offset: usize, len: usize) -> usize {
9+
if bytes.is_empty() {
10+
return 0;
11+
}
12+
13+
let (head, middle, tail) = align_offset_len(bytes, offset, len);
14+
15+
let mut count = head.map_or(0, |v| v.count_ones() as usize);
16+
17+
if !middle.is_empty() {
18+
count += count_ones_aligned(middle);
19+
}
20+
21+
count + tail.map_or(0, |v| v.count_ones() as usize)
22+
}
23+
24+
#[inline]
25+
fn align_offset_len(bytes: &[u8], offset: usize, len: usize) -> (Option<u8>, &[u8], Option<u8>) {
26+
let start_byte = offset / 8;
27+
let start_bit = offset % 8;
28+
let end_bit = offset + len;
29+
let end_byte = end_bit / 8;
30+
let head = (start_bit != 0).then(|| {
31+
let start_len = (8 - start_bit).min(len);
32+
mask_byte(bytes[start_byte], start_bit, start_len)
33+
});
34+
35+
let middle_start = start_byte + usize::from(start_bit != 0);
36+
let middle_end = end_byte;
37+
let middle = if middle_start < middle_end {
38+
&bytes[middle_start..middle_end]
39+
} else {
40+
&[]
41+
};
42+
43+
let consumed = if start_bit != 0 {
44+
(8 - start_bit).min(len)
45+
} else {
46+
0
47+
} + middle.len() * 8;
48+
let tail_len = len - consumed;
49+
let tail = (tail_len != 0).then(|| mask_byte(bytes[middle_end], 0, tail_len));
50+
51+
(head, middle, tail)
52+
}
53+
54+
/// Extracts `bit_len` bits of `byte` starting at `bit_offset` (counted from the
/// least significant bit) and returns them right-aligned, with all higher bits
/// cleared.
///
/// Debug builds assert that `bit_offset < 8` and that the requested range does
/// not extend past the top of the byte.
#[inline]
fn mask_byte(byte: u8, bit_offset: usize, bit_len: usize) -> u8 {
    debug_assert!(bit_offset < 8);
    debug_assert!(bit_len <= 8 - bit_offset);

    let keep = match bit_len {
        // `1u8 << 8` would overflow, so the full-byte mask is spelled out.
        8 => u8::MAX,
        bits => (1u8 << bits) - 1,
    };

    (byte >> bit_offset) & keep
}
68+
69+
#[inline]
70+
fn count_ones_aligned(bytes: &[u8]) -> usize {
71+
#[cfg(target_arch = "x86_64")]
72+
{
73+
if is_x86_feature_detected!("avx512f")
74+
&& is_x86_feature_detected!("avx512vpopcntdq")
75+
&& bytes.len() >= 64
76+
{
77+
// SAFETY: Runtime detection guarantees the required target features.
78+
return unsafe { count_ones_aligned_avx512(bytes) };
79+
}
80+
81+
if is_x86_feature_detected!("avx2") && bytes.len() >= 32 {
82+
// SAFETY: Runtime detection guarantees the required target features.
83+
return unsafe { count_ones_aligned_avx2(bytes) };
84+
}
85+
}
86+
87+
count_ones_aligned_scalar(bytes)
88+
}
89+
90+
#[inline]
91+
fn count_ones_aligned_scalar(bytes: &[u8]) -> usize {
92+
let (words, tail) = bytes.as_chunks::<8>();
93+
let count = words
94+
.iter()
95+
.map(|word| u64::from_le_bytes(*word).count_ones() as usize)
96+
.sum::<usize>();
97+
98+
count
99+
+ tail
100+
.iter()
101+
.map(|byte| byte.count_ones() as usize)
102+
.sum::<usize>()
103+
}
104+
105+
#[cfg(target_arch = "x86_64")]
106+
#[target_feature(enable = "avx2")]
107+
unsafe fn count_ones_aligned_avx2(bytes: &[u8]) -> usize {
108+
use std::arch::x86_64::__m256i;
109+
use std::arch::x86_64::_mm256_add_epi8;
110+
use std::arch::x86_64::_mm256_add_epi64;
111+
use std::arch::x86_64::_mm256_and_si256;
112+
use std::arch::x86_64::_mm256_loadu_si256;
113+
use std::arch::x86_64::_mm256_sad_epu8;
114+
use std::arch::x86_64::_mm256_set1_epi8;
115+
use std::arch::x86_64::_mm256_setr_epi8;
116+
use std::arch::x86_64::_mm256_setzero_si256;
117+
use std::arch::x86_64::_mm256_shuffle_epi8;
118+
use std::arch::x86_64::_mm256_srli_epi16;
119+
use std::arch::x86_64::_mm256_storeu_si256;
120+
121+
#[inline]
122+
unsafe fn byte_popcount(chunk: __m256i, mask: __m256i, lookup: __m256i) -> __m256i {
123+
let lo = unsafe { _mm256_and_si256(chunk, mask) };
124+
let hi = unsafe { _mm256_and_si256(_mm256_srli_epi16(chunk, 4), mask) };
125+
unsafe {
126+
_mm256_add_epi8(
127+
_mm256_shuffle_epi8(lookup, lo),
128+
_mm256_shuffle_epi8(lookup, hi),
129+
)
130+
}
131+
}
132+
133+
let lookup = _mm256_setr_epi8(
134+
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3,
135+
3, 4,
136+
);
137+
let mask = _mm256_set1_epi8(0x0f);
138+
let zero = _mm256_setzero_si256();
139+
let mut accum = _mm256_setzero_si256();
140+
let mut index = 0;
141+
142+
while index + 128 <= bytes.len() {
143+
for lane in 0..4 {
144+
let ptr = unsafe { bytes.as_ptr().add(index + lane * 32) }.cast::<__m256i>();
145+
let chunk = unsafe { _mm256_loadu_si256(ptr) };
146+
let counts = unsafe { byte_popcount(chunk, mask, lookup) };
147+
accum = _mm256_add_epi64(accum, _mm256_sad_epu8(counts, zero));
148+
}
149+
index += 128;
150+
}
151+
152+
while index + 32 <= bytes.len() {
153+
let ptr = unsafe { bytes.as_ptr().add(index) }.cast::<__m256i>();
154+
let chunk = unsafe { _mm256_loadu_si256(ptr) };
155+
let counts = unsafe { byte_popcount(chunk, mask, lookup) };
156+
accum = _mm256_add_epi64(accum, _mm256_sad_epu8(counts, zero));
157+
index += 32;
158+
}
159+
160+
let mut lanes = [0u64; 4];
161+
unsafe { _mm256_storeu_si256(lanes.as_mut_ptr().cast::<__m256i>(), accum) };
162+
163+
usize::try_from(lanes.iter().sum::<u64>()).vortex_expect("true_count doesn't fit in usize")
164+
+ count_ones_aligned_scalar(&bytes[index..])
165+
}
166+
167+
#[cfg(target_arch = "x86_64")]
168+
#[target_feature(enable = "avx512f,avx512vpopcntdq")]
169+
unsafe fn count_ones_aligned_avx512(bytes: &[u8]) -> usize {
170+
use std::arch::x86_64::__m512i;
171+
use std::arch::x86_64::_mm512_add_epi64;
172+
use std::arch::x86_64::_mm512_loadu_si512;
173+
use std::arch::x86_64::_mm512_popcnt_epi64;
174+
use std::arch::x86_64::_mm512_setzero_si512;
175+
use std::arch::x86_64::_mm512_storeu_si512;
176+
177+
let mut accum = _mm512_setzero_si512();
178+
let mut index = 0;
179+
180+
while index + 64 <= bytes.len() {
181+
let ptr = unsafe { bytes.as_ptr().add(index) }.cast::<__m512i>();
182+
let chunk = unsafe { _mm512_loadu_si512(ptr) };
183+
accum = _mm512_add_epi64(accum, _mm512_popcnt_epi64(chunk));
184+
index += 64;
185+
}
186+
187+
let mut lanes = [0u64; 8];
188+
unsafe { _mm512_storeu_si512(lanes.as_mut_ptr().cast::<__m512i>(), accum) };
189+
190+
usize::try_from(lanes.iter().sum::<u64>()).vortex_expect("true_count doesn't fit in usize")
191+
+ count_ones_aligned_scalar(&bytes[index..])
192+
}
193+
194+
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::BitBuffer;

    /// Checks that `true_count` on a sliced `BitBuffer` equals a bit-by-bit
    /// count, for every combination of the listed offsets and slice lengths.
    #[rstest]
    fn test_count_ones_matches_iteration_for_slices(
        // Starting bit of the slice: every value from 0 through 30, so each
        // position within a byte is covered several times over.
        #[values(
            0usize, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
            23, 24, 25, 26, 27, 28, 29, 30
        )]
        offset: usize,
        // Slice lengths on and around powers of two, from empty up to 513 bits.
        #[values(
            0usize, 1, 2, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 255, 256, 257, 513
        )]
        slice_len: usize,
    ) {
        let len = 513;
        // Bit `i` is set when exactly one of `i % 3 == 0` and `i % 11 == 0` holds.
        let buf = BitBuffer::collect_bool(len + 31, |i| (i % 3 == 0) ^ (i % 11 == 0));

        // Skip combinations that would run past the end of the buffer.
        // NOTE(review): if `collect_bool(len + 31, ..)` yields 544 bits, the largest
        // case (30 + 513 = 543) still fits and this guard never fires - confirm.
        if offset + slice_len > buf.len() {
            return;
        }

        let sliced = buf.slice(offset..offset + slice_len);
        // Reference count obtained by visiting every bit of the slice.
        let expected = sliced.iter().filter(|bit| *bit).count();

        assert_eq!(
            sliced.true_count(),
            expected,
            "offset={offset} len={slice_len}"
        );
    }
}

vortex-buffer/src/bit/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
mod arrow;
1111
mod buf;
1212
mod buf_mut;
13+
mod count_ones;
1314
mod macros;
1415
mod ops;
1516

0 commit comments

Comments
 (0)