77
88namespace {
99
10- // Load the `word_idx`-th little-endian u64 of `input`, treating bytes outside
11- // `[0, input_bytes)` as zero. `input` must be 8-byte aligned.
10+ // Transform up to 8 input bytes into a zero-extended 64-bit word:
11+ //
12+ // [ b0 ][ b1 ][ b2 ] | end -> [ b0 ][ b1 ][ b2 ][ 00 ][ 00 ][ 00 ][ 00 ][ 00 ]
1213__device__ uint64_t load_input_word (const uint8_t *const input, int64_t word_idx, uint64_t input_bytes) {
1314 if (word_idx < 0 ) {
1415 return 0 ;
@@ -28,11 +29,17 @@ __device__ uint64_t load_input_word(const uint8_t *const input, int64_t word_idx
2829 return word;
2930}
3031
31- // Build one 64-bit word of the Arrow validity bitmap.
32+ // Build one output word for sliced validity. The row bits are the same, but
33+ // row 0 may live at a different bit position in the source and Arrow bitmaps.
34+ // For example, `input_offset = 5` and `arrow_offset = 0` shifts row0 from bit 5
35+ // in the input bitmap to bit 0 in the Arrow bitmap.
36+ //
37+ // input bitmap: [ . ][ . ][ . ][ . ][ . ][ row0 ][ row1 ][ row2 ]....
38+ // ^ input_offset
39+ // Arrow bitmap: [ row0 ][ row1 ][ row2 ]....
40+ // ^ arrow_offset
3241//
33- // Output bit `b` for `b` in `[arrow_offset, validity_bits)` equals input bit `b + shift`;
34- // all other bits are zero. Two adjacent input words are funnel-shifted to align the input
35- // bits with the output word, then the leading/trailing edges are masked.
42+ // Padding bits are cleared so word-sized validity readers can safely over-read.
3643__device__ uint64_t repack_word (const uint8_t *const input,
3744 uint64_t word_idx,
3845 int64_t shift,
@@ -56,115 +63,129 @@ __device__ uint64_t repack_word(const uint8_t *const input,
5663 return 0 ;
5764 }
5865
59- // `>> 6` floors also for negative bit positions, unlike `/ 64` which truncates toward zero.
60- const int64_t input_bit = static_cast <int64_t >(word_start) + shift;
61- const int64_t input_word = input_bit >> 6 ;
62- const uint32_t bit = static_cast <uint32_t >(input_bit & 63 );
66+ // Each output bit `b` reads source bit `b + shift`.
67+ // `>> 6` floors for negative positions, unlike `/ 64` which truncates toward zero.
68+ const int64_t source_bit_start = static_cast <int64_t >(word_start) + shift;
69+ const int64_t source_word = source_bit_start >> 6 ;
70+ const uint32_t source_bit = static_cast <uint32_t >(source_bit_start & 63 );
6371
64- const uint64_t lo = load_input_word (input, input_word , input_bytes);
65- if (bit == 0 ) {
72+ const uint64_t lo = load_input_word (input, source_word , input_bytes);
73+ if (source_bit == 0 ) {
6674 return lo & mask;
6775 }
68- const uint64_t hi = load_input_word (input, input_word + 1 , input_bytes);
69- return ((lo >> bit ) | (hi << (64 - bit ))) & mask;
76+ const uint64_t hi = load_input_word (input, source_word + 1 , input_bytes);
77+ return ((lo >> source_bit ) | (hi << (64 - source_bit ))) & mask;
7078}
7179
72- // Rebuild a possibly bit-offset Vortex validity bitmap into an Arrow-compatible bitmap.
73- //
74- // `input_offset` is the bit offset into `input`; `arrow_offset` is the logical Arrow array offset
75- // to preserve in the output. Bits outside `[arrow_offset, arrow_offset + len)` are left unset.
76- // The output allocation must hold `ceil((len + arrow_offset) / 64)` full 64-bit words; every
77- // word is written, so no zero-initialization of the output is required.
78- __device__ void arrow_validity_repack_device (const uint8_t *const input,
79- uint64_t *const output,
80- uint64_t len,
81- uint64_t input_offset,
82- uint64_t arrow_offset,
83- uint64_t input_bytes) {
84- // One worker owns a contiguous range of output words. Each word is rebuilt locally so
85- // there are no cross-thread bit writes or atomics.
86- const uint64_t worker = blockIdx .x * blockDim .x + threadIdx .x ;
87- const uint64_t validity_bits = len + arrow_offset;
88- const uint64_t output_words = (validity_bits + 63 ) / 64 ;
89- const uint64_t stride = static_cast <uint64_t >(gridDim .x ) * blockDim .x ;
90-
91- // Translate Arrow-visible output bits back to source bitmap bits. The source bitmap may
92- // start at any bit offset, while Arrow's buffer pointer is byte-addressed.
93- const int64_t shift = static_cast <int64_t >(input_offset) - static_cast <int64_t >(arrow_offset);
94-
95- for (uint64_t word_idx = worker; word_idx < output_words; word_idx += stride) {
96- output[word_idx] = repack_word (input, word_idx, shift, arrow_offset, validity_bits, input_bytes);
97- }
98- }
80+ constexpr uint32_t WARP_SIZE = 32 ;
81+ constexpr uint32_t FULL_WARP_MASK = 0xffffffff ;
9982
83+ // First reduction step for the count kernel: sum one value per lane so each
84+ // warp produces a single partial count.
85+ //
86+ // lanes: [a][b][c][d]... -> lane 0: a+b+c+d+...
10087__device__ uint64_t warp_sum (uint64_t value) {
101- for (int offset = 16 ; offset > 0 ; offset >>= 1 ) {
102- value += __shfl_down_sync (0xffffffff , value, offset);
88+ for (int offset = WARP_SIZE / 2 ; offset > 0 ; offset >>= 1 ) {
89+ value += __shfl_down_sync (FULL_WARP_MASK , value, offset);
10390 }
10491 return value;
10592}
10693
107- __device__ void arrow_validity_count_valid_device (const uint8_t *const input,
108- uint64_t *const output,
109- uint64_t len,
110- uint64_t arrow_offset) {
111- __shared__ uint64_t warp_counts[32 ];
94+ // Mask one bitmap byte down to actual rows. This keeps null counting from
95+ // including Arrow offset padding or trailing padding bits.
96+ //
97+ // byte bits: [ pad ][ row ][ row ][ row ][ pad ]
98+ // mask: [ 0 ][ 1 ][ 1 ][ 1 ][ 0 ]
99+ __device__ uint32_t arrow_validity_byte_mask (uint64_t byte_idx,
100+ uint64_t arrow_offset,
101+ uint64_t validity_bits) {
102+ const uint64_t byte_start = byte_idx * 8 ;
112103
113- const uint32_t thread = threadIdx . x ;
114- const uint64_t worker = blockIdx . x * blockDim . x + thread;
115- const uint64_t validity_bits = len + arrow_offset ;
116- const uint64_t input_bytes = (validity_bits + 7 ) / 8 ;
117- const uint64_t stride = static_cast < uint64_t >( gridDim . x ) * blockDim . x ;
104+ uint32_t mask = 0xff ;
105+ if (byte_start < arrow_offset) {
106+ const uint64_t lead = arrow_offset - byte_start ;
107+ mask = lead >= 8 ? 0 : mask << lead ;
108+ }
118109
119- uint64_t valid_count = 0 ;
120- for (uint64_t byte_idx = worker; byte_idx < input_bytes; byte_idx += stride) {
121- const uint64_t byte_start = byte_idx * 8 ;
122- uint32_t mask = 0xff ;
123- if (byte_start < arrow_offset) {
124- const uint64_t lead = arrow_offset - byte_start;
125- mask = lead >= 8 ? 0 : mask << lead;
126- }
127- const uint64_t remaining = validity_bits - byte_start;
128- if (remaining < 8 ) {
129- mask &= (uint32_t {1 } << remaining) - 1 ;
130- }
131- valid_count += __popc (static_cast <uint32_t >(input[byte_idx]) & mask);
110+ const uint64_t remaining = validity_bits - byte_start;
111+ if (remaining < 8 ) {
112+ mask &= (uint32_t {1 } << remaining) - 1 ;
132113 }
114+ return mask;
115+ }
133116
134- const uint32_t lane = thread & 31 ;
135- const uint32_t warp = thread >> 5 ;
136- valid_count = warp_sum (valid_count);
117+ // Combine warp partial counts into one block total. Only thread 0 returns a
118+ // non-zero value so the count kernel does one global atomic per block.
119+ //
120+ // per-thread counts -> per-warp sums -> block sum -> atomicAdd
121+ __device__ uint64_t block_sum_to_thread_zero (uint64_t value, uint64_t *const warp_counts) {
122+ const uint32_t thread = threadIdx .x ;
123+ const uint32_t lane = thread & (WARP_SIZE - 1 );
124+ const uint32_t warp = thread / WARP_SIZE ;
125+ const uint32_t block_warps = (blockDim .x + WARP_SIZE - 1 ) / WARP_SIZE ;
126+
127+ value = warp_sum (value);
137128 if (lane == 0 ) {
138- warp_counts[warp] = valid_count ;
129+ warp_counts[warp] = value ;
139130 }
140131 __syncthreads ();
141132
142- valid_count = thread < (blockDim .x + 31 ) / 32 ? warp_counts[lane] : 0 ;
143- if (warp == 0 ) {
144- valid_count = warp_sum (valid_count);
145- if (lane == 0 ) {
146- atomicAdd (reinterpret_cast <unsigned long long *>(output),
147- static_cast <unsigned long long >(valid_count));
148- }
149- }
133+ value = lane < block_warps ? warp_counts[lane] : 0 ;
134+ value = warp == 0 ? warp_sum (value) : 0 ;
135+ return thread == 0 ? value : 0 ;
150136}
151137
152138} // namespace
153139
154- // CUDA entry point for validity bitmap repacking used by Arrow Device export.
140+ // Repack sliced validity when the source bitmap offset does not match the
141+ // Arrow array offset. Each thread writes independent output words.
142+ //
143+ // thread 0 -> output word 0, word N, ...
144+ // thread 1 -> output word 1, word N+1, ...
155145extern " C" __global__ void arrow_validity_repack (const uint8_t *const input,
156146 uint64_t *const output,
157147 uint64_t len,
158148 uint64_t input_offset,
159149 uint64_t arrow_offset,
160150 uint64_t input_bytes) {
161- arrow_validity_repack_device (input, output, len, input_offset, arrow_offset, input_bytes);
151+ const uint64_t worker = blockIdx .x * blockDim .x + threadIdx .x ;
152+ const uint64_t validity_bits = len + arrow_offset;
153+ const uint64_t output_words = (validity_bits + 63 ) / 64 ;
154+ const uint64_t stride = static_cast <uint64_t >(gridDim .x ) * blockDim .x ;
155+ const int64_t shift = static_cast <int64_t >(input_offset) - static_cast <int64_t >(arrow_offset);
156+
157+ for (uint64_t word_idx = worker; word_idx < output_words; word_idx += stride) {
158+ output[word_idx] = repack_word (input, word_idx, shift, arrow_offset, validity_bits, input_bytes);
159+ }
162160}
163161
164- // Kernel entry point for counting valid bits in an Arrow validity bitmap.
162+ // Count valid rows directly from the device bitmap so Arrow export can provide
163+ // an exact null_count without copying validity to the CPU.
164+ //
165+ // bytes -> mask padding -> popcount -> block sum -> global count
165166extern " C" __global__ void arrow_validity_count_valid (const uint8_t *const input,
166167 uint64_t *const output,
167168 uint64_t len,
168169 uint64_t arrow_offset) {
169- arrow_validity_count_valid_device (input, output, len, arrow_offset);
170+ __shared__ uint64_t warp_counts[WARP_SIZE ];
171+
172+ const uint64_t validity_bits = len + arrow_offset;
173+ const uint64_t input_bytes = (validity_bits + 7 ) / 8 ;
174+ const uint64_t worker = blockIdx .x * blockDim .x + threadIdx .x ;
175+ const uint64_t stride = static_cast <uint64_t >(gridDim .x ) * blockDim .x ;
176+
177+ // Grid-stride over bitmap bytes. Each byte contributes the popcount of only
178+ // row bits; leading Arrow offset bits and trailing padding bits are masked out.
179+ uint64_t valid_count = 0 ;
180+ for (uint64_t byte_idx = worker; byte_idx < input_bytes; byte_idx += stride) {
181+ const uint32_t mask = arrow_validity_byte_mask (byte_idx, arrow_offset, validity_bits);
182+ valid_count += __popc (static_cast <uint32_t >(input[byte_idx]) & mask);
183+ }
184+
185+ // Reduce within the block first so global contention is one atomic add per block.
186+ valid_count = block_sum_to_thread_zero (valid_count, warp_counts);
187+ if (threadIdx .x == 0 ) {
188+ atomicAdd (reinterpret_cast <unsigned long long *>(output),
189+ static_cast <unsigned long long >(valid_count));
190+ }
170191}
0 commit comments