Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions vortex-cuda/src/dynamic_dispatch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub use plan_builder::MaterializedPlan;

include!(concat!(env!("OUT_DIR"), "/dynamic_dispatch.rs"));

/// Reinterpret a `&T` as a byte slice for serialization into the packed plan.
/// Reinterpret a reference to a packed type as a raw bytes slice.
///
/// # Safety
///
Expand All @@ -58,10 +58,17 @@ include!(concat!(env!("OUT_DIR"), "/dynamic_dispatch.rs"));
/// `PackedStage`, `ScalarOp`) are bindgen-generated `#[repr(C)]` structs.
/// Padding bytes may be uninitialised on the Rust side, but the C reader
/// never inspects them, so the values are irrelevant.
fn as_bytes<T: Sized>(val: &T) -> &[u8] {
unsafe { from_raw_parts(std::ptr::addr_of!(*val).cast(), size_of::<T>()) }
unsafe trait AsPackedBytes: Sized {
/// Reinterpret a `&T` as a byte slice for serialization into the packed plan.
fn as_packed_bytes(&self) -> &[u8] {
unsafe { from_raw_parts(std::ptr::addr_of!(*self).cast(), size_of::<Self>()) }
}
}

unsafe impl AsPackedBytes for PlanHeader {}
unsafe impl AsPackedBytes for PackedStage {}
unsafe impl AsPackedBytes for ScalarOp {}

/// A stage used to build a [`CudaDispatchPlan`] on the host side.
///
/// This is NOT a C ABI struct — it exists purely on the Rust side and is
Expand Down Expand Up @@ -165,7 +172,7 @@ impl CudaDispatchPlan {
num_stages: stages.len() as u8,
plan_size_bytes: total_size as u16,
};
buffer.extend_from_slice(as_bytes(&header));
buffer.extend_from_slice(header.as_packed_bytes());

// Write each stage header followed by its scalar ops.
for stage in &stages {
Expand All @@ -176,9 +183,9 @@ impl CudaDispatchPlan {
source: stage.source,
num_scalar_ops: stage.scalar_ops.len() as u8,
};
buffer.extend_from_slice(as_bytes(&packed_stage));
buffer.extend_from_slice(packed_stage.as_packed_bytes());
for op in &stage.scalar_ops {
buffer.extend_from_slice(as_bytes(op));
buffer.extend_from_slice(op.as_packed_bytes());
}
}

Expand Down
Loading