|
1 | 1 | // SPDX-License-Identifier: Apache-2.0 |
2 | 2 | // SPDX-FileCopyrightText: Copyright the Vortex contributors |
3 | 3 |
|
| 4 | +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] |
| 5 | +use std::arch::is_x86_feature_detected; |
4 | 6 | use std::fmt::Display; |
5 | 7 | use std::fmt::Formatter; |
6 | 8 | use std::fmt::Result as FmtResult; |
@@ -316,7 +318,7 @@ impl BitBuffer { |
316 | 318 |
|
317 | 319 | /// Get the number of set bits in the buffer. |
318 | 320 | pub fn true_count(&self) -> usize { |
319 | | - self.unaligned_chunks().count_ones() |
| 321 | + true_count_impl(self.buffer.as_slice(), self.offset, self.len) |
320 | 322 | } |
321 | 323 |
|
322 | 324 | /// Get the number of unset bits in the buffer. |
@@ -356,6 +358,198 @@ impl BitBuffer { |
356 | 358 | } |
357 | 359 | } |
358 | 360 |
|
| 361 | +#[inline] |
| 362 | +fn true_count_impl(bytes: &[u8], offset: usize, len: usize) -> usize { |
| 363 | + let Some((head, middle, tail)) = byte_aligned_region(bytes, offset, len) else { |
| 364 | + return 0; |
| 365 | + }; |
| 366 | + |
| 367 | + let mut count = head.map_or(0, |v| v.count_ones() as usize); |
| 368 | + |
| 369 | + if !middle.is_empty() { |
| 370 | + count += count_aligned_bytes(middle); |
| 371 | + } |
| 372 | + |
| 373 | + count + tail.map_or(0, |v| v.count_ones() as usize) |
| 374 | +} |
| 375 | + |
| 376 | +#[inline] |
| 377 | +fn byte_aligned_region( |
| 378 | + bytes: &[u8], |
| 379 | + offset: usize, |
| 380 | + len: usize, |
| 381 | +) -> Option<(Option<u8>, &[u8], Option<u8>)> { |
| 382 | + if len == 0 { |
| 383 | + return None; |
| 384 | + } |
| 385 | + |
| 386 | + let start_byte = offset / 8; |
| 387 | + let start_bit = offset % 8; |
| 388 | + let end_bit = offset + len; |
| 389 | + let end_byte = end_bit / 8; |
| 390 | + let head = (start_bit != 0).then(|| { |
| 391 | + let head_len = (8 - start_bit).min(len); |
| 392 | + mask_partial_byte(bytes[start_byte], start_bit, head_len) |
| 393 | + }); |
| 394 | + |
| 395 | + let middle_start = start_byte + usize::from(start_bit != 0); |
| 396 | + let middle_end = end_byte; |
| 397 | + let middle = if middle_start < middle_end { |
| 398 | + &bytes[middle_start..middle_end] |
| 399 | + } else { |
| 400 | + &[] |
| 401 | + }; |
| 402 | + |
| 403 | + let consumed = if start_bit != 0 { |
| 404 | + (8 - start_bit).min(len) |
| 405 | + } else { |
| 406 | + 0 |
| 407 | + } + middle.len() * 8; |
| 408 | + let tail_len = len - consumed; |
| 409 | + let tail = (tail_len != 0).then(|| mask_partial_byte(bytes[middle_end], 0, tail_len)); |
| 410 | + |
| 411 | + Some((head, middle, tail)) |
| 412 | +} |
| 413 | + |
| 414 | +#[inline] |
| 415 | +fn mask_partial_byte(byte: u8, bit_offset: usize, bit_len: usize) -> u8 { |
| 416 | + debug_assert!(bit_offset < 8); |
| 417 | + debug_assert!(bit_len <= 8 - bit_offset); |
| 418 | + |
| 419 | + let shifted = byte >> bit_offset; |
| 420 | + let mask = if bit_len == 8 { |
| 421 | + u8::MAX |
| 422 | + } else { |
| 423 | + (1u8 << bit_len) - 1 |
| 424 | + }; |
| 425 | + |
| 426 | + shifted & mask |
| 427 | +} |
| 428 | + |
| 429 | +#[inline] |
| 430 | +fn count_aligned_bytes(bytes: &[u8]) -> usize { |
| 431 | + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] |
| 432 | + { |
| 433 | + if bytes.len() >= 64 |
| 434 | + && is_x86_feature_detected!("avx512f") |
| 435 | + && is_x86_feature_detected!("avx512vpopcntdq") |
| 436 | + { |
| 437 | + // SAFETY: Runtime detection guarantees the required target features. |
| 438 | + return unsafe { count_aligned_bytes_avx512(bytes) }; |
| 439 | + } |
| 440 | + |
| 441 | + if bytes.len() >= 32 && is_x86_feature_detected!("avx2") { |
| 442 | + // SAFETY: Runtime detection guarantees the required target features. |
| 443 | + return unsafe { count_aligned_bytes_avx2(bytes) }; |
| 444 | + } |
| 445 | + } |
| 446 | + |
| 447 | + count_aligned_bytes_scalar(bytes) |
| 448 | +} |
| 449 | + |
| 450 | +#[inline] |
| 451 | +fn count_aligned_bytes_scalar(bytes: &[u8]) -> usize { |
| 452 | + let (words, tail) = bytes.as_chunks::<8>(); |
| 453 | + let mut count = words |
| 454 | + .iter() |
| 455 | + .map(|word| u64::from_le_bytes(*word).count_ones() as usize) |
| 456 | + .sum::<usize>(); |
| 457 | + |
| 458 | + count += tail |
| 459 | + .iter() |
| 460 | + .map(|byte| byte.count_ones() as usize) |
| 461 | + .sum::<usize>(); |
| 462 | + |
| 463 | + count |
| 464 | +} |
| 465 | + |
| 466 | +#[cfg(target_arch = "x86_64")] |
| 467 | +#[target_feature(enable = "avx2")] |
| 468 | +unsafe fn count_aligned_bytes_avx2(bytes: &[u8]) -> usize { |
| 469 | + use std::arch::x86_64::__m256i; |
| 470 | + use std::arch::x86_64::_mm256_add_epi8; |
| 471 | + use std::arch::x86_64::_mm256_add_epi64; |
| 472 | + use std::arch::x86_64::_mm256_and_si256; |
| 473 | + use std::arch::x86_64::_mm256_loadu_si256; |
| 474 | + use std::arch::x86_64::_mm256_sad_epu8; |
| 475 | + use std::arch::x86_64::_mm256_set1_epi8; |
| 476 | + use std::arch::x86_64::_mm256_setr_epi8; |
| 477 | + use std::arch::x86_64::_mm256_setzero_si256; |
| 478 | + use std::arch::x86_64::_mm256_shuffle_epi8; |
| 479 | + use std::arch::x86_64::_mm256_srli_epi16; |
| 480 | + use std::arch::x86_64::_mm256_storeu_si256; |
| 481 | + |
| 482 | + #[inline] |
| 483 | + unsafe fn byte_popcount(chunk: __m256i, mask: __m256i, lookup: __m256i) -> __m256i { |
| 484 | + let lo = unsafe { _mm256_and_si256(chunk, mask) }; |
| 485 | + let hi = unsafe { _mm256_and_si256(_mm256_srli_epi16(chunk, 4), mask) }; |
| 486 | + unsafe { |
| 487 | + _mm256_add_epi8( |
| 488 | + _mm256_shuffle_epi8(lookup, lo), |
| 489 | + _mm256_shuffle_epi8(lookup, hi), |
| 490 | + ) |
| 491 | + } |
| 492 | + } |
| 493 | + |
| 494 | + let lookup = _mm256_setr_epi8( |
| 495 | + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, |
| 496 | + 3, 4, |
| 497 | + ); |
| 498 | + let mask = _mm256_set1_epi8(0x0f); |
| 499 | + let zero = _mm256_setzero_si256(); |
| 500 | + let mut accum = _mm256_setzero_si256(); |
| 501 | + let mut index = 0; |
| 502 | + |
| 503 | + while index + 128 <= bytes.len() { |
| 504 | + for lane in 0..4 { |
| 505 | + let ptr = unsafe { bytes.as_ptr().add(index + lane * 32) }.cast::<__m256i>(); |
| 506 | + let chunk = unsafe { _mm256_loadu_si256(ptr) }; |
| 507 | + let counts = unsafe { byte_popcount(chunk, mask, lookup) }; |
| 508 | + accum = _mm256_add_epi64(accum, _mm256_sad_epu8(counts, zero)); |
| 509 | + } |
| 510 | + index += 128; |
| 511 | + } |
| 512 | + |
| 513 | + while index + 32 <= bytes.len() { |
| 514 | + let ptr = unsafe { bytes.as_ptr().add(index) }.cast::<__m256i>(); |
| 515 | + let chunk = unsafe { _mm256_loadu_si256(ptr) }; |
| 516 | + let counts = unsafe { byte_popcount(chunk, mask, lookup) }; |
| 517 | + accum = _mm256_add_epi64(accum, _mm256_sad_epu8(counts, zero)); |
| 518 | + index += 32; |
| 519 | + } |
| 520 | + |
| 521 | + let mut lanes = [0u64; 4]; |
| 522 | + unsafe { _mm256_storeu_si256(lanes.as_mut_ptr().cast::<__m256i>(), accum) }; |
| 523 | + |
| 524 | + lanes.iter().sum::<u64>() as usize + count_aligned_bytes_scalar(&bytes[index..]) |
| 525 | +} |
| 526 | + |
| 527 | +#[cfg(target_arch = "x86_64")] |
| 528 | +#[target_feature(enable = "avx512f,avx512vpopcntdq")] |
| 529 | +unsafe fn count_aligned_bytes_avx512(bytes: &[u8]) -> usize { |
| 530 | + use std::arch::x86_64::__m512i; |
| 531 | + use std::arch::x86_64::_mm512_add_epi64; |
| 532 | + use std::arch::x86_64::_mm512_loadu_si512; |
| 533 | + use std::arch::x86_64::_mm512_popcnt_epi64; |
| 534 | + use std::arch::x86_64::_mm512_setzero_si512; |
| 535 | + use std::arch::x86_64::_mm512_storeu_si512; |
| 536 | + |
| 537 | + let mut accum = _mm512_setzero_si512(); |
| 538 | + let mut index = 0; |
| 539 | + |
| 540 | + while index + 64 <= bytes.len() { |
| 541 | + let ptr = unsafe { bytes.as_ptr().add(index) }.cast::<__m512i>(); |
| 542 | + let chunk = unsafe { _mm512_loadu_si512(ptr) }; |
| 543 | + accum = _mm512_add_epi64(accum, _mm512_popcnt_epi64(chunk)); |
| 544 | + index += 64; |
| 545 | + } |
| 546 | + |
| 547 | + let mut lanes = [0u64; 8]; |
| 548 | + unsafe { _mm512_storeu_si512(lanes.as_mut_ptr().cast::<__m512i>(), accum) }; |
| 549 | + |
| 550 | + lanes.iter().sum::<u64>() as usize + count_aligned_bytes_scalar(&bytes[index..]) |
| 551 | +} |
| 552 | + |
359 | 553 | // Conversions |
360 | 554 |
|
361 | 555 | impl BitBuffer { |
@@ -784,4 +978,33 @@ mod tests { |
784 | 978 | assert_eq!(mapped.value(i), expected, "Mismatch at index {}", i); |
785 | 979 | } |
786 | 980 | } |
| 981 | + |
| 982 | + #[rstest] |
| 983 | + fn test_true_count_matches_iteration_for_slices( |
| 984 | + #[values( |
| 985 | + 0usize, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, |
| 986 | + 23, 24, 25, 26, 27, 28, 29, 30 |
| 987 | + )] |
| 988 | + offset: usize, |
| 989 | + #[values( |
| 990 | + 0usize, 1, 2, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 255, 256, 257, 513 |
| 991 | + )] |
| 992 | + slice_len: usize, |
| 993 | + ) { |
| 994 | + let len = 513; |
| 995 | + let buf = BitBuffer::collect_bool(len + 31, |i| (i % 3 == 0) ^ (i % 11 == 0)); |
| 996 | + |
| 997 | + if offset + slice_len > buf.len() { |
| 998 | + return; |
| 999 | + } |
| 1000 | + |
| 1001 | + let sliced = buf.slice(offset..offset + slice_len); |
| 1002 | + let expected = sliced.iter().filter(|bit| *bit).count(); |
| 1003 | + |
| 1004 | + assert_eq!( |
| 1005 | + sliced.true_count(), |
| 1006 | + expected, |
| 1007 | + "offset={offset} len={slice_len}" |
| 1008 | + ); |
| 1009 | + } |
787 | 1010 | } |
0 commit comments