Skip to content

Commit 4a0ed9b

Browse files
authored
feat(cuda): add SEQUENCE source op to dyn dispatch (#7078)
Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 4829e13 commit 4a0ed9b

4 files changed

Lines changed: 109 additions & 1 deletion

File tree

vortex-cuda/kernels/src/dynamic_dispatch.cu

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,17 @@ __device__ inline void dynamic_source_op(const T *__restrict input,
114114
return;
115115
}
116116

117+
case SourceOp::SEQUENCE: {
118+
// Generate a linear sequence: value[i] = base + i * multiplier.
119+
// Used for SequenceArray (e.g. monotonic run-end endpoints).
120+
const T base = static_cast<T>(source_op.params.sequence.base);
121+
const T mul = static_cast<T>(source_op.params.sequence.multiplier);
122+
for (uint32_t i = threadIdx.x; i < chunk_len; i += blockDim.x) {
123+
smem_output[i] = base + static_cast<T>(chunk_start + i) * mul;
124+
}
125+
break;
126+
}
127+
117128
default:
118129
__builtin_unreachable();
119130
}

vortex-cuda/kernels/src/dynamic_dispatch.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,16 @@ union SourceParams {
6060
uint64_t num_runs;
6161
uint64_t offset;
6262
} runend;
63+
64+
/// Generate a linear sequence: `value[i] = base + i * multiplier`.
65+
struct SequenceParams {
66+
int64_t base;
67+
int64_t multiplier;
68+
} sequence;
6369
};
6470

6571
struct SourceOp {
66-
enum SourceOpCode { BITUNPACK, LOAD, RUNEND } op_code;
72+
enum SourceOpCode { BITUNPACK, LOAD, RUNEND, SEQUENCE } op_code;
6773
union SourceParams params;
6874
};
6975

vortex-cuda/src/dynamic_dispatch/mod.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,17 @@ impl SourceOp {
7575
},
7676
}
7777
}
78+
79+
/// Generate a linear sequence: `value[i] = base + i * multiplier`.
80+
/// Used for SequenceArray (e.g. monotonic run-end endpoints).
81+
pub fn sequence(base: i64, multiplier: i64) -> Self {
82+
Self {
83+
op_code: SourceOp_SourceOpCode_SEQUENCE,
84+
params: SourceParams {
85+
sequence: SourceParams_SequenceParams { base, multiplier },
86+
},
87+
}
88+
}
7889
}
7990

8091
impl ScalarOp {
@@ -1004,4 +1015,60 @@ mod tests {
10041015

10051016
Ok(())
10061017
}
1018+
1019+
#[rstest]
1020+
#[case(0u32, 1u32, 100)]
1021+
#[case(5u32, 3u32, 2048)]
1022+
#[case(0u32, 1u32, 4096)]
1023+
#[case(100u32, 7u32, 5000)]
1024+
#[crate::test]
1025+
fn test_sequence_unsigned(
1026+
#[case] base: u32,
1027+
#[case] multiplier: u32,
1028+
#[case] len: usize,
1029+
) -> VortexResult<()> {
1030+
use vortex::dtype::Nullability;
1031+
use vortex::encodings::sequence::SequenceArray;
1032+
1033+
let expected: Vec<u32> = (0..len).map(|i| base + (i as u32) * multiplier).collect();
1034+
1035+
let seq = SequenceArray::try_new_typed(base, multiplier, Nullability::NonNullable, len)?;
1036+
1037+
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1038+
let (plan, _bufs) = build_plan(&seq.into_array(), &cuda_ctx)?;
1039+
1040+
let actual = run_dynamic_dispatch_plan(&cuda_ctx, expected.len(), &plan)?;
1041+
assert_eq!(actual, expected);
1042+
1043+
Ok(())
1044+
}
1045+
1046+
#[rstest]
1047+
#[case(0i32, 1i32, 100)]
1048+
#[case(-10i32, 3i32, 2048)]
1049+
#[case(100i32, -1i32, 100)]
1050+
#[case(-500i32, -7i32, 50)]
1051+
#[case(0i32, 1i32, 5000)]
1052+
#[crate::test]
1053+
fn test_sequence_signed(
1054+
#[case] base: i32,
1055+
#[case] multiplier: i32,
1056+
#[case] len: usize,
1057+
) -> VortexResult<()> {
1058+
use vortex::dtype::Nullability;
1059+
use vortex::encodings::sequence::SequenceArray;
1060+
1061+
let expected: Vec<i32> = (0..len).map(|i| base + (i as i32) * multiplier).collect();
1062+
1063+
let seq = SequenceArray::try_new_typed(base, multiplier, Nullability::NonNullable, len)?;
1064+
1065+
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1066+
let (plan, _bufs) = build_plan(&seq.into_array(), &cuda_ctx)?;
1067+
1068+
let actual_u32 = run_dynamic_dispatch_plan(&cuda_ctx, expected.len(), &plan)?;
1069+
let actual: Vec<i32> = actual_u32.into_iter().map(|v| v as i32).collect();
1070+
assert_eq!(actual, expected);
1071+
1072+
Ok(())
1073+
}
10071074
}

vortex-cuda/src/dynamic_dispatch/plan_builder.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ use vortex::encodings::fastlanes::FoR;
2626
use vortex::encodings::fastlanes::FoRArray;
2727
use vortex::encodings::runend::RunEnd;
2828
use vortex::encodings::runend::RunEndArrayParts;
29+
use vortex::encodings::sequence::Sequence;
30+
use vortex::encodings::sequence::SequenceArrayParts;
2931
use vortex::encodings::zigzag::ZigZag;
3032
use vortex::error::VortexResult;
3133
use vortex::error::vortex_bail;
@@ -82,6 +84,7 @@ struct Pipeline {
8284
/// - `ALPArray` → recurse + `ALP` scalar op (f32 only, no patches)
8385
/// - `DictArray` → input stage for values + recurse codes + `DICT` scalar op
8486
/// - `RunEndArray` → input stages for ends/values + `RUNEND` source
87+
/// - `SequenceArray` → `SEQUENCE` source (integer ptypes only)
8588
/// - `SliceArray` → resolve via child's slice reduce/kernel
8689
///
8790
/// # Limitations
@@ -158,6 +161,8 @@ impl PlanBuilderState<'_> {
158161
self.walk_primitive(array)
159162
} else if id == Slice::ID {
160163
self.walk_slice(array)
164+
} else if id == Sequence::ID {
165+
self.walk_sequence(array)
161166
} else {
162167
vortex_bail!(
163168
"Encoding {:?} not supported by dynamic dispatch plan builder",
@@ -305,6 +310,25 @@ impl PlanBuilderState<'_> {
305310
Ok(pipeline)
306311
}
307312

313+
/// SequenceArray → SEQUENCE source op
314+
///
315+
/// Generates `value[i] = base + i * multiplier` on the GPU.
316+
fn walk_sequence(&mut self, array: ArrayRef) -> VortexResult<Pipeline> {
317+
let seq = array
318+
.try_into::<Sequence>()
319+
.map_err(|_| vortex_err!("Expected SequenceArray"))?;
320+
let SequenceArrayParts {
321+
base, multiplier, ..
322+
} = seq.into_parts();
323+
324+
Ok(Pipeline {
325+
source: SourceOp::sequence(base.cast()?, multiplier.cast()?),
326+
scalar_ops: vec![],
327+
// SEQUENCE does not have an input pointer.
328+
input_ptr: 0,
329+
})
330+
}
331+
308332
/// RunEndArray → add input stages for ends and values, RUNEND source op.
309333
fn walk_runend(&mut self, array: ArrayRef) -> VortexResult<Pipeline> {
310334
let re = array

0 commit comments

Comments
 (0)