Skip to content

Commit 44dacba

Browse files
committed
refactor
Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent 38d1e05 commit 44dacba

3 files changed

Lines changed: 216 additions & 218 deletions

File tree

vortex-buffer/src/bit/buf.rs

Lines changed: 2 additions & 218 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
5-
use std::arch::is_x86_feature_detected;
4+
use crate::bit::count_ones::count_ones;
65
use std::fmt::Display;
76
use std::fmt::Formatter;
87
use std::fmt::Result as FmtResult;
@@ -318,7 +317,7 @@ impl BitBuffer {
318317

319318
/// Get the number of set bits in the buffer.
320319
pub fn true_count(&self) -> usize {
321-
true_count_impl(self.buffer.as_slice(), self.offset, self.len)
320+
count_ones(self.buffer.as_slice(), self.offset, self.len)
322321
}
323322

324323
/// Get the number of unset bits in the buffer.
@@ -358,192 +357,6 @@ impl BitBuffer {
358357
}
359358
}
360359

361-
#[inline]
362-
fn true_count_impl(bytes: &[u8], offset: usize, len: usize) -> usize {
363-
if bytes.is_empty() {
364-
return 0;
365-
}
366-
367-
let (head, middle, tail) = byte_aligned_region(bytes, offset, len);
368-
369-
let mut count = head.map_or(0, |v| v.count_ones() as usize);
370-
371-
if !middle.is_empty() {
372-
count += count_aligned_bytes(middle);
373-
}
374-
375-
count + tail.map_or(0, |v| v.count_ones() as usize)
376-
}
377-
378-
#[inline]
379-
fn byte_aligned_region(bytes: &[u8], offset: usize, len: usize) -> (Option<u8>, &[u8], Option<u8>) {
380-
let start_byte = offset / 8;
381-
let start_bit = offset % 8;
382-
let end_bit = offset + len;
383-
let end_byte = end_bit / 8;
384-
let head = (start_bit != 0).then(|| {
385-
let head_len = (8 - start_bit).min(len);
386-
mask_partial_byte(bytes[start_byte], start_bit, head_len)
387-
});
388-
389-
let middle_start = start_byte + usize::from(start_bit != 0);
390-
let middle_end = end_byte;
391-
let middle = if middle_start < middle_end {
392-
&bytes[middle_start..middle_end]
393-
} else {
394-
&[]
395-
};
396-
397-
let consumed = if start_bit != 0 {
398-
(8 - start_bit).min(len)
399-
} else {
400-
0
401-
} + middle.len() * 8;
402-
let tail_len = len - consumed;
403-
let tail = (tail_len != 0).then(|| mask_partial_byte(bytes[middle_end], 0, tail_len));
404-
405-
(head, middle, tail)
406-
}
407-
408-
#[inline]
409-
fn mask_partial_byte(byte: u8, bit_offset: usize, bit_len: usize) -> u8 {
410-
debug_assert!(bit_offset < 8);
411-
debug_assert!(bit_len <= 8 - bit_offset);
412-
413-
let shifted = byte >> bit_offset;
414-
let mask = if bit_len == 8 {
415-
u8::MAX
416-
} else {
417-
(1u8 << bit_len) - 1
418-
};
419-
420-
shifted & mask
421-
}
422-
423-
#[inline]
424-
fn count_aligned_bytes(bytes: &[u8]) -> usize {
425-
#[cfg(target_arch = "x86_64")]
426-
{
427-
if bytes.len() >= 64
428-
&& is_x86_feature_detected!("avx512f")
429-
&& is_x86_feature_detected!("avx512vpopcntdq")
430-
{
431-
// SAFETY: Runtime detection guarantees the required target features.
432-
return unsafe { count_aligned_bytes_avx512(bytes) };
433-
}
434-
435-
if bytes.len() >= 32 && is_x86_feature_detected!("avx2") {
436-
// SAFETY: Runtime detection guarantees the required target features.
437-
return unsafe { count_aligned_bytes_avx2(bytes) };
438-
}
439-
}
440-
441-
count_aligned_bytes_scalar(bytes)
442-
}
443-
444-
#[inline]
445-
fn count_aligned_bytes_scalar(bytes: &[u8]) -> usize {
446-
let (words, tail) = bytes.as_chunks::<8>();
447-
let mut count = words
448-
.iter()
449-
.map(|word| u64::from_le_bytes(*word).count_ones() as usize)
450-
.sum::<usize>();
451-
452-
count += tail
453-
.iter()
454-
.map(|byte| byte.count_ones() as usize)
455-
.sum::<usize>();
456-
457-
count
458-
}
459-
460-
#[cfg(target_arch = "x86_64")]
461-
#[target_feature(enable = "avx2")]
462-
unsafe fn count_aligned_bytes_avx2(bytes: &[u8]) -> usize {
463-
use std::arch::x86_64::__m256i;
464-
use std::arch::x86_64::_mm256_add_epi8;
465-
use std::arch::x86_64::_mm256_add_epi64;
466-
use std::arch::x86_64::_mm256_and_si256;
467-
use std::arch::x86_64::_mm256_loadu_si256;
468-
use std::arch::x86_64::_mm256_sad_epu8;
469-
use std::arch::x86_64::_mm256_set1_epi8;
470-
use std::arch::x86_64::_mm256_setr_epi8;
471-
use std::arch::x86_64::_mm256_setzero_si256;
472-
use std::arch::x86_64::_mm256_shuffle_epi8;
473-
use std::arch::x86_64::_mm256_srli_epi16;
474-
use std::arch::x86_64::_mm256_storeu_si256;
475-
476-
#[inline]
477-
unsafe fn byte_popcount(chunk: __m256i, mask: __m256i, lookup: __m256i) -> __m256i {
478-
let lo = unsafe { _mm256_and_si256(chunk, mask) };
479-
let hi = unsafe { _mm256_and_si256(_mm256_srli_epi16(chunk, 4), mask) };
480-
unsafe {
481-
_mm256_add_epi8(
482-
_mm256_shuffle_epi8(lookup, lo),
483-
_mm256_shuffle_epi8(lookup, hi),
484-
)
485-
}
486-
}
487-
488-
let lookup = _mm256_setr_epi8(
489-
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3,
490-
3, 4,
491-
);
492-
let mask = _mm256_set1_epi8(0x0f);
493-
let zero = _mm256_setzero_si256();
494-
let mut accum = _mm256_setzero_si256();
495-
let mut index = 0;
496-
497-
while index + 128 <= bytes.len() {
498-
for lane in 0..4 {
499-
let ptr = unsafe { bytes.as_ptr().add(index + lane * 32) }.cast::<__m256i>();
500-
let chunk = unsafe { _mm256_loadu_si256(ptr) };
501-
let counts = unsafe { byte_popcount(chunk, mask, lookup) };
502-
accum = _mm256_add_epi64(accum, _mm256_sad_epu8(counts, zero));
503-
}
504-
index += 128;
505-
}
506-
507-
while index + 32 <= bytes.len() {
508-
let ptr = unsafe { bytes.as_ptr().add(index) }.cast::<__m256i>();
509-
let chunk = unsafe { _mm256_loadu_si256(ptr) };
510-
let counts = unsafe { byte_popcount(chunk, mask, lookup) };
511-
accum = _mm256_add_epi64(accum, _mm256_sad_epu8(counts, zero));
512-
index += 32;
513-
}
514-
515-
let mut lanes = [0u64; 4];
516-
unsafe { _mm256_storeu_si256(lanes.as_mut_ptr().cast::<__m256i>(), accum) };
517-
518-
lanes.iter().sum::<u64>() as usize + count_aligned_bytes_scalar(&bytes[index..])
519-
}
520-
521-
#[cfg(target_arch = "x86_64")]
522-
#[target_feature(enable = "avx512f,avx512vpopcntdq")]
523-
unsafe fn count_aligned_bytes_avx512(bytes: &[u8]) -> usize {
524-
use std::arch::x86_64::__m512i;
525-
use std::arch::x86_64::_mm512_add_epi64;
526-
use std::arch::x86_64::_mm512_loadu_si512;
527-
use std::arch::x86_64::_mm512_popcnt_epi64;
528-
use std::arch::x86_64::_mm512_setzero_si512;
529-
use std::arch::x86_64::_mm512_storeu_si512;
530-
531-
let mut accum = _mm512_setzero_si512();
532-
let mut index = 0;
533-
534-
while index + 64 <= bytes.len() {
535-
let ptr = unsafe { bytes.as_ptr().add(index) }.cast::<__m512i>();
536-
let chunk = unsafe { _mm512_loadu_si512(ptr) };
537-
accum = _mm512_add_epi64(accum, _mm512_popcnt_epi64(chunk));
538-
index += 64;
539-
}
540-
541-
let mut lanes = [0u64; 8];
542-
unsafe { _mm512_storeu_si512(lanes.as_mut_ptr().cast::<__m512i>(), accum) };
543-
544-
lanes.iter().sum::<u64>() as usize + count_aligned_bytes_scalar(&bytes[index..])
545-
}
546-
547360
// Conversions
548361

549362
impl BitBuffer {
@@ -972,33 +785,4 @@ mod tests {
972785
assert_eq!(mapped.value(i), expected, "Mismatch at index {}", i);
973786
}
974787
}
975-
976-
#[rstest]
977-
fn test_true_count_matches_iteration_for_slices(
978-
#[values(
979-
0usize, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
980-
23, 24, 25, 26, 27, 28, 29, 30
981-
)]
982-
offset: usize,
983-
#[values(
984-
0usize, 1, 2, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 255, 256, 257, 513
985-
)]
986-
slice_len: usize,
987-
) {
988-
let len = 513;
989-
let buf = BitBuffer::collect_bool(len + 31, |i| (i % 3 == 0) ^ (i % 11 == 0));
990-
991-
if offset + slice_len > buf.len() {
992-
return;
993-
}
994-
995-
let sliced = buf.slice(offset..offset + slice_len);
996-
let expected = sliced.iter().filter(|bit| *bit).count();
997-
998-
assert_eq!(
999-
sliced.true_count(),
1000-
expected,
1001-
"offset={offset} len={slice_len}"
1002-
);
1003-
}
1004788
}

0 commit comments

Comments
 (0)