|
| 1 | +// SPDX-License-Identifier: Apache-2.0 |
| 2 | +// SPDX-FileCopyrightText: Copyright the Vortex contributors |
| 3 | + |
| 4 | +#[cfg(target_arch = "x86_64")] |
| 5 | +use vortex_error::VortexExpect; |
| 6 | + |
| 7 | +#[inline] |
| 8 | +pub fn count_ones(bytes: &[u8], offset: usize, len: usize) -> usize { |
| 9 | + if bytes.is_empty() { |
| 10 | + return 0; |
| 11 | + } |
| 12 | + |
| 13 | + let (head, middle, tail) = align_offset_len(bytes, offset, len); |
| 14 | + |
| 15 | + let mut count = head.map_or(0, |v| v.count_ones() as usize); |
| 16 | + |
| 17 | + if !middle.is_empty() { |
| 18 | + count += count_ones_aligned(middle); |
| 19 | + } |
| 20 | + |
| 21 | + count + tail.map_or(0, |v| v.count_ones() as usize) |
| 22 | +} |
| 23 | + |
| 24 | +#[inline] |
| 25 | +fn align_offset_len(bytes: &[u8], offset: usize, len: usize) -> (Option<u8>, &[u8], Option<u8>) { |
| 26 | + let start_byte = offset / 8; |
| 27 | + let start_bit = offset % 8; |
| 28 | + let end_bit = offset + len; |
| 29 | + let end_byte = end_bit / 8; |
| 30 | + let head = (start_bit != 0).then(|| { |
| 31 | + let start_len = (8 - start_bit).min(len); |
| 32 | + mask_byte(bytes[start_byte], start_bit, start_len) |
| 33 | + }); |
| 34 | + |
| 35 | + let middle_start = start_byte + usize::from(start_bit != 0); |
| 36 | + let middle_end = end_byte; |
| 37 | + let middle = if middle_start < middle_end { |
| 38 | + &bytes[middle_start..middle_end] |
| 39 | + } else { |
| 40 | + &[] |
| 41 | + }; |
| 42 | + |
| 43 | + let consumed = if start_bit != 0 { |
| 44 | + (8 - start_bit).min(len) |
| 45 | + } else { |
| 46 | + 0 |
| 47 | + } + middle.len() * 8; |
| 48 | + let tail_len = len - consumed; |
| 49 | + let tail = (tail_len != 0).then(|| mask_byte(bytes[middle_end], 0, tail_len)); |
| 50 | + |
| 51 | + (head, middle, tail) |
| 52 | +} |
| 53 | + |
/// Extracts `bit_len` bits of `byte` starting at bit `bit_offset`.
///
/// The selected bits are shifted down so that they start at bit 0 of the
/// result; all higher bits are cleared.
///
/// Debug builds assert that `bit_offset < 8` and that the selected range stays
/// within the byte (`bit_len <= 8 - bit_offset`).
#[inline]
fn mask_byte(byte: u8, bit_offset: usize, bit_len: usize) -> u8 {
    debug_assert!(bit_offset < 8);
    debug_assert!(bit_len <= 8 - bit_offset);

    // `1u8 << 8` would overflow, so a full-width selection gets its own arm.
    let keep = match bit_len {
        8 => u8::MAX,
        n => (1u8 << n) - 1,
    };

    (byte >> bit_offset) & keep
}
| 68 | + |
| 69 | +#[inline] |
| 70 | +fn count_ones_aligned(bytes: &[u8]) -> usize { |
| 71 | + #[cfg(target_arch = "x86_64")] |
| 72 | + { |
| 73 | + if bytes.len() >= 64 |
| 74 | + && is_x86_feature_detected!("avx512f") |
| 75 | + && is_x86_feature_detected!("avx512vpopcntdq") |
| 76 | + { |
| 77 | + // SAFETY: Runtime detection guarantees the required target features. |
| 78 | + return unsafe { count_ones_aligned_avx512(bytes) }; |
| 79 | + } |
| 80 | + |
| 81 | + if bytes.len() >= 32 && is_x86_feature_detected!("avx2") { |
| 82 | + // SAFETY: Runtime detection guarantees the required target features. |
| 83 | + return unsafe { count_ones_aligned_avx2(bytes) }; |
| 84 | + } |
| 85 | + } |
| 86 | + |
| 87 | + count_ones_aligned_scalar(bytes) |
| 88 | +} |
| 89 | + |
| 90 | +#[inline] |
| 91 | +fn count_ones_aligned_scalar(bytes: &[u8]) -> usize { |
| 92 | + let (words, tail) = bytes.as_chunks::<8>(); |
| 93 | + let count = words |
| 94 | + .iter() |
| 95 | + .map(|word| u64::from_le_bytes(*word).count_ones() as usize) |
| 96 | + .sum::<usize>(); |
| 97 | + |
| 98 | + count |
| 99 | + + tail |
| 100 | + .iter() |
| 101 | + .map(|byte| byte.count_ones() as usize) |
| 102 | + .sum::<usize>() |
| 103 | +} |
| 104 | + |
| 105 | +#[cfg(target_arch = "x86_64")] |
| 106 | +#[target_feature(enable = "avx2")] |
| 107 | +unsafe fn count_ones_aligned_avx2(bytes: &[u8]) -> usize { |
| 108 | + use std::arch::x86_64::__m256i; |
| 109 | + use std::arch::x86_64::_mm256_add_epi8; |
| 110 | + use std::arch::x86_64::_mm256_add_epi64; |
| 111 | + use std::arch::x86_64::_mm256_and_si256; |
| 112 | + use std::arch::x86_64::_mm256_loadu_si256; |
| 113 | + use std::arch::x86_64::_mm256_sad_epu8; |
| 114 | + use std::arch::x86_64::_mm256_set1_epi8; |
| 115 | + use std::arch::x86_64::_mm256_setr_epi8; |
| 116 | + use std::arch::x86_64::_mm256_setzero_si256; |
| 117 | + use std::arch::x86_64::_mm256_shuffle_epi8; |
| 118 | + use std::arch::x86_64::_mm256_srli_epi16; |
| 119 | + use std::arch::x86_64::_mm256_storeu_si256; |
| 120 | + |
| 121 | + #[inline] |
| 122 | + unsafe fn byte_popcount(chunk: __m256i, mask: __m256i, lookup: __m256i) -> __m256i { |
| 123 | + let lo = unsafe { _mm256_and_si256(chunk, mask) }; |
| 124 | + let hi = unsafe { _mm256_and_si256(_mm256_srli_epi16(chunk, 4), mask) }; |
| 125 | + unsafe { |
| 126 | + _mm256_add_epi8( |
| 127 | + _mm256_shuffle_epi8(lookup, lo), |
| 128 | + _mm256_shuffle_epi8(lookup, hi), |
| 129 | + ) |
| 130 | + } |
| 131 | + } |
| 132 | + |
| 133 | + let lookup = _mm256_setr_epi8( |
| 134 | + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, |
| 135 | + 3, 4, |
| 136 | + ); |
| 137 | + let mask = _mm256_set1_epi8(0x0f); |
| 138 | + let zero = _mm256_setzero_si256(); |
| 139 | + let mut accum = _mm256_setzero_si256(); |
| 140 | + let mut index = 0; |
| 141 | + |
| 142 | + while index + 128 <= bytes.len() { |
| 143 | + for lane in 0..4 { |
| 144 | + let ptr = unsafe { bytes.as_ptr().add(index + lane * 32) }.cast::<__m256i>(); |
| 145 | + let chunk = unsafe { _mm256_loadu_si256(ptr) }; |
| 146 | + let counts = unsafe { byte_popcount(chunk, mask, lookup) }; |
| 147 | + accum = _mm256_add_epi64(accum, _mm256_sad_epu8(counts, zero)); |
| 148 | + } |
| 149 | + index += 128; |
| 150 | + } |
| 151 | + |
| 152 | + while index + 32 <= bytes.len() { |
| 153 | + let ptr = unsafe { bytes.as_ptr().add(index) }.cast::<__m256i>(); |
| 154 | + let chunk = unsafe { _mm256_loadu_si256(ptr) }; |
| 155 | + let counts = unsafe { byte_popcount(chunk, mask, lookup) }; |
| 156 | + accum = _mm256_add_epi64(accum, _mm256_sad_epu8(counts, zero)); |
| 157 | + index += 32; |
| 158 | + } |
| 159 | + |
| 160 | + let mut lanes = [0u64; 4]; |
| 161 | + unsafe { _mm256_storeu_si256(lanes.as_mut_ptr().cast::<__m256i>(), accum) }; |
| 162 | + |
| 163 | + usize::try_from(lanes.iter().sum::<u64>()).vortex_expect("true_count doesn't fit in usize") |
| 164 | + + count_ones_aligned_scalar(&bytes[index..]) |
| 165 | +} |
| 166 | + |
| 167 | +#[cfg(target_arch = "x86_64")] |
| 168 | +#[target_feature(enable = "avx512f,avx512vpopcntdq")] |
| 169 | +unsafe fn count_ones_aligned_avx512(bytes: &[u8]) -> usize { |
| 170 | + use std::arch::x86_64::__m512i; |
| 171 | + use std::arch::x86_64::_mm512_add_epi64; |
| 172 | + use std::arch::x86_64::_mm512_loadu_si512; |
| 173 | + use std::arch::x86_64::_mm512_popcnt_epi64; |
| 174 | + use std::arch::x86_64::_mm512_setzero_si512; |
| 175 | + use std::arch::x86_64::_mm512_storeu_si512; |
| 176 | + |
| 177 | + let mut accum = _mm512_setzero_si512(); |
| 178 | + let mut index = 0; |
| 179 | + |
| 180 | + while index + 64 <= bytes.len() { |
| 181 | + let ptr = unsafe { bytes.as_ptr().add(index) }.cast::<__m512i>(); |
| 182 | + let chunk = unsafe { _mm512_loadu_si512(ptr) }; |
| 183 | + accum = _mm512_add_epi64(accum, _mm512_popcnt_epi64(chunk)); |
| 184 | + index += 64; |
| 185 | + } |
| 186 | + |
| 187 | + let mut lanes = [0u64; 8]; |
| 188 | + unsafe { _mm512_storeu_si512(lanes.as_mut_ptr().cast::<__m512i>(), accum) }; |
| 189 | + |
| 190 | + usize::try_from(lanes.iter().sum::<u64>()).vortex_expect("true_count doesn't fit in usize") |
| 191 | + + count_ones_aligned_scalar(&bytes[index..]) |
| 192 | +} |
| 193 | + |
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::BitBuffer;

    /// Checks that `true_count` on a sliced buffer agrees with counting the
    /// `true` bits one by one, across many offset/length combinations.
    ///
    /// NOTE(review): presumably `BitBuffer::true_count` is backed by
    /// `count_ones` from this module — confirm against `BitBuffer`.
    #[cfg_attr(miri, ignore)]
    #[rstest]
    fn test_count_ones_matches_iteration_for_slices(
        #[values(
            0usize, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
            23, 24, 25, 26, 27, 28, 29, 30
        )]
        offset: usize,
        #[values(
            0usize, 1, 2, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 255, 256, 257, 513
        )]
        slice_len: usize,
    ) {
        // 513 payload bits plus 31 spare bits.
        const TOTAL_BITS: usize = 513 + 31;

        // A bit pattern with no byte-level periodicity.
        let buf = BitBuffer::collect_bool(TOTAL_BITS, |i| (i % 3 == 0) ^ (i % 11 == 0));

        // Combinations that would run past the buffer are skipped.
        let end = offset + slice_len;
        if end > buf.len() {
            return;
        }

        let sliced = buf.slice(offset..end);
        let expected = sliced.iter().filter(|&bit| bit).count();

        assert_eq!(
            sliced.true_count(),
            expected,
            "offset={offset} len={slice_len}"
        );
    }
}
0 commit comments