Skip to content

Commit 4ad0356

Browse files
committed
bitcount kernel
Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 27e680a commit 4ad0356

4 files changed

Lines changed: 192 additions & 26 deletions

File tree

vortex-cuda/benches/arrow_validity_cuda.rs

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,56 @@ fn benchmark_arrow_validity_repack(c: &mut Criterion) {
179179
group.finish();
180180
}
181181

182+
fn benchmark_arrow_validity_count_nulls(c: &mut Criterion) {
183+
let mut group = c.benchmark_group("cuda");
184+
185+
for &(len, len_label) in bench_config::BENCH_SIZES {
186+
group.throughput(Throughput::Bytes(
187+
validity_bitmap_byte_len(len, ARROW_OFFSET) as u64,
188+
));
189+
group.bench_with_input(
190+
BenchmarkId::new("cuda/arrow_validity/count_nulls", len_label),
191+
&len,
192+
|b, &len| {
193+
b.iter_custom(|iters| {
194+
let timed = TimedLaunchStrategy::default();
195+
let timer = timed.timer();
196+
197+
let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
198+
.vortex_expect("failed to create execution context")
199+
.with_launch_strategy(Arc::new(timed));
200+
let source = BitBuffer::collect_bool(len + ARROW_OFFSET, |idx| {
201+
idx >= ARROW_OFFSET && idx % 3 != 0
202+
});
203+
let (_, _, input_buffer) = source.into_inner();
204+
let input_buffer =
205+
block_on(cuda_ctx.ensure_on_device(BufferHandle::new_host(input_buffer)))
206+
.vortex_expect("failed to copy validity input to device");
207+
208+
for _ in 0..iters {
209+
let null_count = test_harness::count_arrow_validity_nulls(
210+
&input_buffer,
211+
len,
212+
ARROW_OFFSET,
213+
&mut cuda_ctx,
214+
)
215+
.vortex_expect("failed to count Arrow validity nulls");
216+
std::hint::black_box(null_count);
217+
}
218+
219+
Duration::from_nanos(timer.load(Ordering::Relaxed))
220+
});
221+
},
222+
);
223+
}
224+
225+
group.finish();
226+
}
227+
182228
criterion::criterion_group! {
183229
name = benches;
184230
config = bench_config::cuda_bench_config();
185-
targets = benchmark_arrow_validity_repack, benchmark_arrow_validity_export
231+
targets = benchmark_arrow_validity_repack, benchmark_arrow_validity_count_nulls, benchmark_arrow_validity_export
186232
}
187233

188234
#[cuda_available]

vortex-cuda/kernels/src/arrow_validity.cu

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,58 @@ __device__ void arrow_validity_repack_device(const uint8_t *const input,
9797
}
9898
}
9999

100+
__device__ uint64_t warp_sum(uint64_t value) {
101+
for (int offset = 16; offset > 0; offset >>= 1) {
102+
value += __shfl_down_sync(0xffffffff, value, offset);
103+
}
104+
return value;
105+
}
106+
107+
__device__ void arrow_validity_count_valid_device(const uint8_t *const input,
108+
uint64_t *const output,
109+
uint64_t len,
110+
uint64_t arrow_offset) {
111+
__shared__ uint64_t warp_counts[32];
112+
113+
const uint32_t thread = threadIdx.x;
114+
const uint64_t worker = blockIdx.x * blockDim.x + thread;
115+
const uint64_t validity_bits = len + arrow_offset;
116+
const uint64_t input_bytes = (validity_bits + 7) / 8;
117+
const uint64_t stride = static_cast<uint64_t>(gridDim.x) * blockDim.x;
118+
119+
uint64_t valid_count = 0;
120+
for (uint64_t byte_idx = worker; byte_idx < input_bytes; byte_idx += stride) {
121+
const uint64_t byte_start = byte_idx * 8;
122+
uint32_t mask = 0xff;
123+
if (byte_start < arrow_offset) {
124+
const uint64_t lead = arrow_offset - byte_start;
125+
mask = lead >= 8 ? 0 : mask << lead;
126+
}
127+
const uint64_t remaining = validity_bits - byte_start;
128+
if (remaining < 8) {
129+
mask &= (uint32_t {1} << remaining) - 1;
130+
}
131+
valid_count += __popc(static_cast<uint32_t>(input[byte_idx]) & mask);
132+
}
133+
134+
const uint32_t lane = thread & 31;
135+
const uint32_t warp = thread >> 5;
136+
valid_count = warp_sum(valid_count);
137+
if (lane == 0) {
138+
warp_counts[warp] = valid_count;
139+
}
140+
__syncthreads();
141+
142+
valid_count = thread < (blockDim.x + 31) / 32 ? warp_counts[lane] : 0;
143+
if (warp == 0) {
144+
valid_count = warp_sum(valid_count);
145+
if (lane == 0) {
146+
atomicAdd(reinterpret_cast<unsigned long long *>(output),
147+
static_cast<unsigned long long>(valid_count));
148+
}
149+
}
150+
}
151+
100152
} // namespace
101153

102154
// CUDA entry point for validity bitmap repacking used by Arrow Device export.
@@ -108,3 +160,11 @@ extern "C" __global__ void arrow_validity_repack(const uint8_t *const input,
108160
uint64_t input_bytes) {
109161
arrow_validity_repack_device(input, output, len, input_offset, arrow_offset, input_bytes);
110162
}
163+
164+
// Kernel entry point for counting valid bits in an Arrow validity bitmap.
165+
extern "C" __global__ void arrow_validity_count_valid(const uint8_t *const input,
166+
uint64_t *const output,
167+
uint64_t len,
168+
uint64_t arrow_offset) {
169+
arrow_validity_count_valid_device(input, output, len, arrow_offset);
170+
}

vortex-cuda/src/arrow/canonical.rs

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,9 @@ pub(super) async fn export_arrow_validity_buffer(
804804
} else {
805805
repack_arrow_validity_buffer(&bitmap, meta.offset(), len, arrow_offset, ctx)?
806806
};
807-
Ok((Some(bitmap), ARROW_UNKNOWN_NULL_COUNT))
807+
// Keep nullable exports self-describing for consumers that require exact null counts.
808+
let null_count = count_arrow_validity_nulls(&bitmap, len, arrow_offset, ctx)?;
809+
Ok((Some(bitmap), null_count))
808810
}
809811
}
810812
}
@@ -817,9 +819,6 @@ fn validity_bitmap_byte_len(len: usize, arrow_offset: usize) -> VortexResult<usi
817819
.div_ceil(8))
818820
}
819821

820-
/// Arrow C Data uses -1 when the null count has not been computed.
821-
const ARROW_UNKNOWN_NULL_COUNT: i64 = -1;
822-
823822
/// Allocate a zeroed device buffer with cuDF-safe padding for Arrow validity masks.
824823
fn device_zeroed_byte_buffer(
825824
byte_len: usize,
@@ -833,6 +832,62 @@ fn device_zeroed_byte_buffer(
833832
Ok(BufferHandle::new_device(Arc::new(CudaDeviceBuffer::new(buffer))).slice(0..byte_len))
834833
}
835834

835+
pub(super) fn count_arrow_validity_nulls(
836+
bitmap: &BufferHandle,
837+
len: usize,
838+
arrow_offset: usize,
839+
ctx: &mut CudaExecutionCtx,
840+
) -> VortexResult<i64> {
841+
if len == 0 {
842+
return Ok(0);
843+
}
844+
845+
let expected_bytes = validity_bitmap_byte_len(len, arrow_offset)?;
846+
vortex_ensure!(
847+
bitmap.len() >= expected_bytes,
848+
"Arrow validity bitmap has {} bytes, expected at least {expected_bytes}",
849+
bitmap.len()
850+
);
851+
852+
let mut count = ctx.device_alloc::<u64>(1)?;
853+
ctx.stream()
854+
.memset_zeros(&mut count)
855+
.map_err(|err| vortex_err!("Failed to zero Arrow validity count buffer: {err}"))?;
856+
let count = CudaDeviceBuffer::new(count);
857+
858+
let input_view = bitmap.cuda_view::<u8>()?;
859+
let output_view = count.as_view::<u64>();
860+
let len = u64::try_from(len)?;
861+
let arrow_offset = u64::try_from(arrow_offset)?;
862+
863+
let kernel = ctx.load_function_with_suffixes("arrow_validity", &["count_valid"])?;
864+
const COUNT_THREADS_PER_BLOCK: u32 = 256;
865+
const MAX_COUNT_BLOCKS: u32 = 4096;
866+
let num_blocks = u32::try_from(expected_bytes.div_ceil(COUNT_THREADS_PER_BLOCK as usize))?
867+
.clamp(1, MAX_COUNT_BLOCKS);
868+
let config = LaunchConfig {
869+
grid_dim: (num_blocks, 1, 1),
870+
block_dim: (COUNT_THREADS_PER_BLOCK, 1, 1),
871+
shared_mem_bytes: 0,
872+
};
873+
ctx.launch_kernel_config(&kernel, config, expected_bytes, |args| {
874+
args.arg(&input_view)
875+
.arg(&output_view)
876+
.arg(&len)
877+
.arg(&arrow_offset);
878+
})?;
879+
880+
let valid_count = ctx
881+
.stream()
882+
.clone_dtoh(&output_view)
883+
.map_err(|err| vortex_err!("Failed to copy Arrow validity count to host: {err}"))?
884+
.into_iter()
885+
.next()
886+
.ok_or_else(|| vortex_err!("Arrow validity count kernel returned no output"))?;
887+
888+
Ok(i64::try_from(len - valid_count)?)
889+
}
890+
836891
/// Repack a validity bitmap into Arrow layout without copying bitmap bits back to the CPU.
837892
///
838893
/// Vortex bitmaps may start at any bit offset. Arrow exposes only a byte-addressed validity buffer
@@ -1665,10 +1720,7 @@ mod tests {
16651720
)
16661721
);
16671722
assert_eq!(exported.array.array.length, 4);
1668-
assert_eq!(
1669-
exported.array.array.null_count,
1670-
super::ARROW_UNKNOWN_NULL_COUNT
1671-
);
1723+
assert_eq!(exported.array.array.null_count, 1);
16721724
assert_eq!(exported.array.array.n_buffers, 2);
16731725
assert_eq!(exported.array.array.n_children, 0);
16741726
assert!(!exported.array.array.dictionary.is_null());
@@ -1726,7 +1778,7 @@ mod tests {
17261778
let children = unsafe { std::slice::from_raw_parts(exported.array.array.children, 1) };
17271779
let dict_child = unsafe { &*children[0] };
17281780
assert!(!dict_child.dictionary.is_null());
1729-
assert_eq!(dict_child.null_count, super::ARROW_UNKNOWN_NULL_COUNT);
1781+
assert_eq!(dict_child.null_count, 1);
17301782

17311783
unsafe { release_exported_array(&raw mut exported.array.array) };
17321784
Ok(())
@@ -1759,7 +1811,7 @@ mod tests {
17591811
[0, 1, 0]
17601812
);
17611813
let dictionary = unsafe { &*exported.array.array.dictionary };
1762-
assert_eq!(dictionary.null_count, super::ARROW_UNKNOWN_NULL_COUNT);
1814+
assert_eq!(dictionary.null_count, 1);
17631815
assert_eq!(private_data_buffer_i32_values(dictionary, 1)?.len(), 2);
17641816

17651817
unsafe { release_exported_array(&raw mut exported.array.array) };
@@ -2036,7 +2088,7 @@ mod tests {
20362088
assert_binary_layout(
20372089
&exported.array.array,
20382090
5,
2039-
super::ARROW_UNKNOWN_NULL_COUNT,
2091+
1,
20402092
&[0, 0, 3, 3, 8, i32::try_from(8 + out_of_line.len())?],
20412093
&[b"\x00\xff\xfe".as_slice(), b"short", out_of_line].concat(),
20422094
)?;
@@ -2550,7 +2602,7 @@ mod tests {
25502602
[0, 1, 0, 1, 2]
25512603
);
25522604
let dictionary = unsafe { &*elements.dictionary };
2553-
assert_eq!(dictionary.null_count, super::ARROW_UNKNOWN_NULL_COUNT);
2605+
assert_eq!(dictionary.null_count, 1);
25542606

25552607
unsafe { release_exported_array(&raw mut exported.array.array) };
25562608
Ok(())
@@ -2796,7 +2848,7 @@ mod tests {
27962848
let mut exported = utf8.export_device_array_with_schema(&mut ctx).await?;
27972849
let field = Field::try_from(&exported.schema)?;
27982850
assert_eq!(field, Field::new("", DataType::Utf8View, true));
2799-
assert_varbinview_shape(&exported.array.array, 3, super::ARROW_UNKNOWN_NULL_COUNT)?;
2851+
assert_varbinview_shape(&exported.array.array, 3, 1)?;
28002852
assert_eq!(exported.array.device_type, ARROW_DEVICE_CUDA);
28012853

28022854
let private_data = unsafe { &*exported.array.array.private_data.cast::<PrivateData>() };
@@ -2824,7 +2876,7 @@ mod tests {
28242876
assert_binary_layout(
28252877
&exported.array.array,
28262878
3,
2827-
super::ARROW_UNKNOWN_NULL_COUNT,
2879+
1,
28282880
&[0, 0, 2, i32::try_from(2 + sliced_out_of_line.len())?],
28292881
&[b"\x00\xff".as_slice(), sliced_out_of_line].concat(),
28302882
)?;
@@ -2942,7 +2994,7 @@ mod tests {
29422994
let mut primitive = assert_nullable_export(
29432995
PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array(),
29442996
2,
2945-
super::ARROW_UNKNOWN_NULL_COUNT,
2997+
1,
29462998
&mut ctx,
29472999
)
29483000
.await?;
@@ -2971,7 +3023,7 @@ mod tests {
29713023
.into_array()
29723024
.slice(1..4)?,
29733025
2,
2974-
super::ARROW_UNKNOWN_NULL_COUNT,
3026+
1,
29753027
&mut ctx,
29763028
)
29773029
.await?;
@@ -3020,7 +3072,7 @@ mod tests {
30203072
)
30213073
.into_array(),
30223074
2,
3023-
super::ARROW_UNKNOWN_NULL_COUNT,
3075+
1,
30243076
&mut ctx,
30253077
)
30263078
.await?;
@@ -3050,7 +3102,7 @@ mod tests {
30503102
)
30513103
.into_array(),
30523104
2,
3053-
super::ARROW_UNKNOWN_NULL_COUNT,
3105+
1,
30543106
&mut ctx,
30553107
)
30563108
.await?;
@@ -3073,7 +3125,7 @@ mod tests {
30733125
])
30743126
.into_array(),
30753127
4,
3076-
super::ARROW_UNKNOWN_NULL_COUNT,
3128+
1,
30773129
&mut ctx,
30783130
)
30793131
.await?;
@@ -3097,7 +3149,7 @@ mod tests {
30973149
)?
30983150
.into_array(),
30993151
1,
3100-
super::ARROW_UNKNOWN_NULL_COUNT,
3152+
1,
31013153
&mut ctx,
31023154
)
31033155
.await?;
@@ -3106,10 +3158,7 @@ mod tests {
31063158
Ok(())
31073159
}
31083160

3109-
// A container's row validity may be a non-canonical bool array (e.g. dict-encoded, or
3110-
// produced by take/scan). The container export paths bypass execute_cuda, so
3111-
// export_arrow_validity_buffer must canonicalize the validity on the GPU itself instead of
3112-
// bailing.
3161+
// Non-canonical row validity should export as a device-resident bitmap.
31133162
#[crate::test]
31143163
async fn test_export_struct_non_canonical_validity() -> VortexResult<()> {
31153164
let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
@@ -3131,7 +3180,7 @@ mod tests {
31313180
)?
31323181
.into_array(),
31333182
1,
3134-
super::ARROW_UNKNOWN_NULL_COUNT,
3183+
1,
31353184
&mut ctx,
31363185
)
31373186
.await?;

vortex-cuda/src/arrow/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,19 @@ pub mod test_harness {
7070
use vortex::error::VortexResult;
7171

7272
use crate::CudaExecutionCtx;
73+
use crate::arrow::canonical::count_arrow_validity_nulls as count_arrow_validity_nulls_impl;
7374
use crate::arrow::canonical::repack_arrow_validity_buffer as repack_arrow_validity_buffer_impl;
7475

76+
/// Count null bits in an Arrow validity bitmap.
77+
pub fn count_arrow_validity_nulls(
78+
bitmap: &BufferHandle,
79+
len: usize,
80+
arrow_offset: usize,
81+
ctx: &mut CudaExecutionCtx,
82+
) -> VortexResult<i64> {
83+
count_arrow_validity_nulls_impl(bitmap, len, arrow_offset, ctx)
84+
}
85+
7586
/// Repack a validity bitmap into Arrow's byte-addressed bitmap layout on the active stream.
7687
pub fn repack_arrow_validity_buffer(
7788
input_buffer: &BufferHandle,

0 commit comments

Comments
 (0)