|
1 | 1 | // SPDX-License-Identifier: Apache-2.0 |
2 | 2 | // SPDX-FileCopyrightText: Copyright the Vortex contributors |
3 | 3 |
|
4 | | -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] |
5 | | -use std::arch::is_x86_feature_detected; |
| 4 | +use crate::bit::count_ones::count_ones; |
6 | 5 | use std::fmt::Display; |
7 | 6 | use std::fmt::Formatter; |
8 | 7 | use std::fmt::Result as FmtResult; |
@@ -318,7 +317,7 @@ impl BitBuffer { |
318 | 317 |
|
319 | 318 | /// Get the number of set bits in the buffer, i.e. within the `len` bits that start `offset` bits into the backing bytes. |
320 | 319 | pub fn true_count(&self) -> usize { |
321 | | - true_count_impl(self.buffer.as_slice(), self.offset, self.len) |
| 320 | + count_ones(self.buffer.as_slice(), self.offset, self.len) |
322 | 321 | } |
323 | 322 |
|
324 | 323 | /// Get the number of unset bits in the buffer. |
@@ -358,192 +357,6 @@ impl BitBuffer { |
358 | 357 | } |
359 | 358 | } |
360 | 359 |
|
361 | | -#[inline] |
362 | | -fn true_count_impl(bytes: &[u8], offset: usize, len: usize) -> usize { |
363 | | - if bytes.is_empty() { |
364 | | - return 0; |
365 | | - } |
366 | | - |
367 | | - let (head, middle, tail) = byte_aligned_region(bytes, offset, len); |
368 | | - |
369 | | - let mut count = head.map_or(0, |v| v.count_ones() as usize); |
370 | | - |
371 | | - if !middle.is_empty() { |
372 | | - count += count_aligned_bytes(middle); |
373 | | - } |
374 | | - |
375 | | - count + tail.map_or(0, |v| v.count_ones() as usize) |
376 | | -} |
377 | | - |
378 | | -#[inline] |
379 | | -fn byte_aligned_region(bytes: &[u8], offset: usize, len: usize) -> (Option<u8>, &[u8], Option<u8>) { |
380 | | - let start_byte = offset / 8; |
381 | | - let start_bit = offset % 8; |
382 | | - let end_bit = offset + len; |
383 | | - let end_byte = end_bit / 8; |
384 | | - let head = (start_bit != 0).then(|| { |
385 | | - let head_len = (8 - start_bit).min(len); |
386 | | - mask_partial_byte(bytes[start_byte], start_bit, head_len) |
387 | | - }); |
388 | | - |
389 | | - let middle_start = start_byte + usize::from(start_bit != 0); |
390 | | - let middle_end = end_byte; |
391 | | - let middle = if middle_start < middle_end { |
392 | | - &bytes[middle_start..middle_end] |
393 | | - } else { |
394 | | - &[] |
395 | | - }; |
396 | | - |
397 | | - let consumed = if start_bit != 0 { |
398 | | - (8 - start_bit).min(len) |
399 | | - } else { |
400 | | - 0 |
401 | | - } + middle.len() * 8; |
402 | | - let tail_len = len - consumed; |
403 | | - let tail = (tail_len != 0).then(|| mask_partial_byte(bytes[middle_end], 0, tail_len)); |
404 | | - |
405 | | - (head, middle, tail) |
406 | | -} |
407 | | - |
/// Extracts `bit_len` bits of `byte` starting at bit `bit_offset` (bit 0 is the least
/// significant bit) and returns them in the low bits of the result.
///
/// Debug builds assert that `bit_offset < 8` and that the requested bits fit in the byte.
#[inline]
fn mask_partial_byte(byte: u8, bit_offset: usize, bit_len: usize) -> u8 {
    debug_assert!(bit_offset < 8);
    debug_assert!(bit_len <= 8 - bit_offset);

    // `1u8 << 8` would overflow, so a full-byte request gets the all-ones mask directly.
    let keep = match bit_len {
        8 => u8::MAX,
        n => (1u8 << n) - 1,
    };

    (byte >> bit_offset) & keep
}
422 | | - |
423 | | -#[inline] |
424 | | -fn count_aligned_bytes(bytes: &[u8]) -> usize { |
425 | | - #[cfg(target_arch = "x86_64")] |
426 | | - { |
427 | | - if bytes.len() >= 64 |
428 | | - && is_x86_feature_detected!("avx512f") |
429 | | - && is_x86_feature_detected!("avx512vpopcntdq") |
430 | | - { |
431 | | - // SAFETY: Runtime detection guarantees the required target features. |
432 | | - return unsafe { count_aligned_bytes_avx512(bytes) }; |
433 | | - } |
434 | | - |
435 | | - if bytes.len() >= 32 && is_x86_feature_detected!("avx2") { |
436 | | - // SAFETY: Runtime detection guarantees the required target features. |
437 | | - return unsafe { count_aligned_bytes_avx2(bytes) }; |
438 | | - } |
439 | | - } |
440 | | - |
441 | | - count_aligned_bytes_scalar(bytes) |
442 | | -} |
443 | | - |
444 | | -#[inline] |
445 | | -fn count_aligned_bytes_scalar(bytes: &[u8]) -> usize { |
446 | | - let (words, tail) = bytes.as_chunks::<8>(); |
447 | | - let mut count = words |
448 | | - .iter() |
449 | | - .map(|word| u64::from_le_bytes(*word).count_ones() as usize) |
450 | | - .sum::<usize>(); |
451 | | - |
452 | | - count += tail |
453 | | - .iter() |
454 | | - .map(|byte| byte.count_ones() as usize) |
455 | | - .sum::<usize>(); |
456 | | - |
457 | | - count |
458 | | -} |
459 | | - |
460 | | -#[cfg(target_arch = "x86_64")] |
461 | | -#[target_feature(enable = "avx2")] |
462 | | -unsafe fn count_aligned_bytes_avx2(bytes: &[u8]) -> usize { |
463 | | - use std::arch::x86_64::__m256i; |
464 | | - use std::arch::x86_64::_mm256_add_epi8; |
465 | | - use std::arch::x86_64::_mm256_add_epi64; |
466 | | - use std::arch::x86_64::_mm256_and_si256; |
467 | | - use std::arch::x86_64::_mm256_loadu_si256; |
468 | | - use std::arch::x86_64::_mm256_sad_epu8; |
469 | | - use std::arch::x86_64::_mm256_set1_epi8; |
470 | | - use std::arch::x86_64::_mm256_setr_epi8; |
471 | | - use std::arch::x86_64::_mm256_setzero_si256; |
472 | | - use std::arch::x86_64::_mm256_shuffle_epi8; |
473 | | - use std::arch::x86_64::_mm256_srli_epi16; |
474 | | - use std::arch::x86_64::_mm256_storeu_si256; |
475 | | - |
476 | | - #[inline] |
477 | | - unsafe fn byte_popcount(chunk: __m256i, mask: __m256i, lookup: __m256i) -> __m256i { |
478 | | - let lo = unsafe { _mm256_and_si256(chunk, mask) }; |
479 | | - let hi = unsafe { _mm256_and_si256(_mm256_srli_epi16(chunk, 4), mask) }; |
480 | | - unsafe { |
481 | | - _mm256_add_epi8( |
482 | | - _mm256_shuffle_epi8(lookup, lo), |
483 | | - _mm256_shuffle_epi8(lookup, hi), |
484 | | - ) |
485 | | - } |
486 | | - } |
487 | | - |
488 | | - let lookup = _mm256_setr_epi8( |
489 | | - 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, |
490 | | - 3, 4, |
491 | | - ); |
492 | | - let mask = _mm256_set1_epi8(0x0f); |
493 | | - let zero = _mm256_setzero_si256(); |
494 | | - let mut accum = _mm256_setzero_si256(); |
495 | | - let mut index = 0; |
496 | | - |
497 | | - while index + 128 <= bytes.len() { |
498 | | - for lane in 0..4 { |
499 | | - let ptr = unsafe { bytes.as_ptr().add(index + lane * 32) }.cast::<__m256i>(); |
500 | | - let chunk = unsafe { _mm256_loadu_si256(ptr) }; |
501 | | - let counts = unsafe { byte_popcount(chunk, mask, lookup) }; |
502 | | - accum = _mm256_add_epi64(accum, _mm256_sad_epu8(counts, zero)); |
503 | | - } |
504 | | - index += 128; |
505 | | - } |
506 | | - |
507 | | - while index + 32 <= bytes.len() { |
508 | | - let ptr = unsafe { bytes.as_ptr().add(index) }.cast::<__m256i>(); |
509 | | - let chunk = unsafe { _mm256_loadu_si256(ptr) }; |
510 | | - let counts = unsafe { byte_popcount(chunk, mask, lookup) }; |
511 | | - accum = _mm256_add_epi64(accum, _mm256_sad_epu8(counts, zero)); |
512 | | - index += 32; |
513 | | - } |
514 | | - |
515 | | - let mut lanes = [0u64; 4]; |
516 | | - unsafe { _mm256_storeu_si256(lanes.as_mut_ptr().cast::<__m256i>(), accum) }; |
517 | | - |
518 | | - lanes.iter().sum::<u64>() as usize + count_aligned_bytes_scalar(&bytes[index..]) |
519 | | -} |
520 | | - |
521 | | -#[cfg(target_arch = "x86_64")] |
522 | | -#[target_feature(enable = "avx512f,avx512vpopcntdq")] |
523 | | -unsafe fn count_aligned_bytes_avx512(bytes: &[u8]) -> usize { |
524 | | - use std::arch::x86_64::__m512i; |
525 | | - use std::arch::x86_64::_mm512_add_epi64; |
526 | | - use std::arch::x86_64::_mm512_loadu_si512; |
527 | | - use std::arch::x86_64::_mm512_popcnt_epi64; |
528 | | - use std::arch::x86_64::_mm512_setzero_si512; |
529 | | - use std::arch::x86_64::_mm512_storeu_si512; |
530 | | - |
531 | | - let mut accum = _mm512_setzero_si512(); |
532 | | - let mut index = 0; |
533 | | - |
534 | | - while index + 64 <= bytes.len() { |
535 | | - let ptr = unsafe { bytes.as_ptr().add(index) }.cast::<__m512i>(); |
536 | | - let chunk = unsafe { _mm512_loadu_si512(ptr) }; |
537 | | - accum = _mm512_add_epi64(accum, _mm512_popcnt_epi64(chunk)); |
538 | | - index += 64; |
539 | | - } |
540 | | - |
541 | | - let mut lanes = [0u64; 8]; |
542 | | - unsafe { _mm512_storeu_si512(lanes.as_mut_ptr().cast::<__m512i>(), accum) }; |
543 | | - |
544 | | - lanes.iter().sum::<u64>() as usize + count_aligned_bytes_scalar(&bytes[index..]) |
545 | | -} |
546 | | - |
547 | 360 | // Conversions |
548 | 361 |
|
549 | 362 | impl BitBuffer { |
@@ -972,33 +785,4 @@ mod tests { |
972 | 785 | assert_eq!(mapped.value(i), expected, "Mismatch at index {}", i); |
973 | 786 | } |
974 | 787 | } |
975 | | - |
976 | | - #[rstest] |
977 | | - fn test_true_count_matches_iteration_for_slices( |
978 | | - #[values( |
979 | | - 0usize, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, |
980 | | - 23, 24, 25, 26, 27, 28, 29, 30 |
981 | | - )] |
982 | | - offset: usize, |
983 | | - #[values( |
984 | | - 0usize, 1, 2, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 255, 256, 257, 513 |
985 | | - )] |
986 | | - slice_len: usize, |
987 | | - ) { |
988 | | - let len = 513; |
989 | | - let buf = BitBuffer::collect_bool(len + 31, |i| (i % 3 == 0) ^ (i % 11 == 0)); |
990 | | - |
991 | | - if offset + slice_len > buf.len() { |
992 | | - return; |
993 | | - } |
994 | | - |
995 | | - let sliced = buf.slice(offset..offset + slice_len); |
996 | | - let expected = sliced.iter().filter(|bit| *bit).count(); |
997 | | - |
998 | | - assert_eq!( |
999 | | - sliced.true_count(), |
1000 | | - expected, |
1001 | | - "offset={offset} len={slice_len}" |
1002 | | - ); |
1003 | | - } |
1004 | 788 | } |
0 commit comments