diff --git a/vortex-cuda/src/dynamic_dispatch/mod.rs b/vortex-cuda/src/dynamic_dispatch/mod.rs index 4cf2560a766..90a848d268a 100644 --- a/vortex-cuda/src/dynamic_dispatch/mod.rs +++ b/vortex-cuda/src/dynamic_dispatch/mod.rs @@ -49,7 +49,7 @@ pub use plan_builder::MaterializedPlan; include!(concat!(env!("OUT_DIR"), "/dynamic_dispatch.rs")); -/// Reinterpret a `&T` as a byte slice for serialization into the packed plan. +/// Reinterpret a reference to a packed type as a raw bytes slice. /// /// # Safety /// @@ -58,10 +58,17 @@ include!(concat!(env!("OUT_DIR"), "/dynamic_dispatch.rs")); /// `PackedStage`, `ScalarOp`) are bindgen-generated `#[repr(C)]` structs. /// Padding bytes may be uninitialised on the Rust side, but the C reader /// never inspects them, so the values are irrelevant. -fn as_bytes(val: &T) -> &[u8] { - unsafe { from_raw_parts(std::ptr::addr_of!(*val).cast(), size_of::()) } +unsafe trait AsPackedBytes: Sized { + /// Reinterpret a `&T` as a byte slice for serialization into the packed plan. + fn as_packed_bytes(&self) -> &[u8] { + unsafe { from_raw_parts(std::ptr::addr_of!(*self).cast(), size_of::()) } + } } +unsafe impl AsPackedBytes for PlanHeader {} +unsafe impl AsPackedBytes for PackedStage {} +unsafe impl AsPackedBytes for ScalarOp {} + /// A stage used to build a [`CudaDispatchPlan`] on the host side. /// /// This is NOT a C ABI struct — it exists purely on the Rust side and is @@ -165,7 +172,7 @@ impl CudaDispatchPlan { num_stages: stages.len() as u8, plan_size_bytes: total_size as u16, }; - buffer.extend_from_slice(as_bytes(&header)); + buffer.extend_from_slice(header.as_packed_bytes()); // Write each stage header followed by its scalar ops. for stage in &stages { @@ -176,9 +183,9 @@ impl CudaDispatchPlan { source: stage.source, num_scalar_ops: stage.scalar_ops.len() as u8, }; - buffer.extend_from_slice(as_bytes(&packed_stage)); + buffer.extend_from_slice(packed_stage.as_packed_bytes()); for op in &stage.scalar_ops { - buffer.extend_from_slice(as_bytes(op)); + buffer.extend_from_slice(op.as_packed_bytes()); } }