diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 687b9d33fef..4a63f15090d 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -3358,7 +3358,7 @@ pub fn vortex_array::arrays::patched::Patched::child_name(array: vortex_array::A pub fn vortex_array::arrays::patched::Patched::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult> -pub fn vortex_array::arrays::patched::Patched::execute(array: vortex_array::Array, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_array::arrays::patched::Patched::execute(array: vortex_array::Array, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_array::arrays::patched::Patched::execute_parent(array: vortex_array::ArrayView<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -6198,7 +6198,7 @@ pub fn vortex_array::arrays::patched::Patched::child_name(array: vortex_array::A pub fn vortex_array::arrays::patched::Patched::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult> -pub fn vortex_array::arrays::patched::Patched::execute(array: vortex_array::Array, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_array::arrays::patched::Patched::execute(array: vortex_array::Array, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_array::arrays::patched::Patched::execute_parent(array: vortex_array::ArrayView<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -19856,7 +19856,7 @@ pub fn vortex_array::arrays::patched::Patched::child_name(array: vortex_array::A pub fn vortex_array::arrays::patched::Patched::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult> -pub fn vortex_array::arrays::patched::Patched::execute(array: vortex_array::Array, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_array::arrays::patched::Patched::execute(array: vortex_array::Array, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_array::arrays::patched::Patched::execute_parent(array: vortex_array::ArrayView<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -20872,7 +20872,7 @@ pub fn vortex_array::arrays::patched::Patched::child_name(array: vortex_array::A pub fn vortex_array::arrays::patched::Patched::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult> -pub fn vortex_array::arrays::patched::Patched::execute(array: vortex_array::Array, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_array::arrays::patched::Patched::execute(array: vortex_array::Array, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_array::arrays::patched::Patched::execute_parent(array: vortex_array::ArrayView<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -23576,7 +23576,7 @@ pub fn vortex_array::arrays::patched::Patched::child_name(array: vortex_array::A pub fn vortex_array::arrays::patched::Patched::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult> -pub fn vortex_array::arrays::patched::Patched::execute(array: vortex_array::Array, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_array::arrays::patched::Patched::execute(array: vortex_array::Array, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_array::arrays::patched::Patched::execute_parent(array: vortex_array::ArrayView<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -24840,7 +24840,7 @@ pub fn vortex_array::arrays::patched::Patched::child_name(array: vortex_array::A pub fn vortex_array::arrays::patched::Patched::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult> -pub fn vortex_array::arrays::patched::Patched::execute(array: vortex_array::Array, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_array::arrays::patched::Patched::execute(array: vortex_array::Array, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_array::arrays::patched::Patched::execute_parent(array: vortex_array::ArrayView<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> diff --git a/vortex-array/src/arrays/patched/vtable/mod.rs b/vortex-array/src/arrays/patched/vtable/mod.rs index 5b8f900c7f0..9e5c2a6e77c 100644 --- a/vortex-array/src/arrays/patched/vtable/mod.rs +++ b/vortex-array/src/arrays/patched/vtable/mod.rs @@ -31,10 +31,15 @@ use crate::array::ArrayView; use crate::array::VTable; use crate::array::ValidityChild; use crate::array::ValidityVTableFromChild; +use crate::arrays::Primitive; use crate::arrays::PrimitiveArray; use crate::arrays::patched::PatchedArrayExt; use crate::arrays::patched::PatchedData; +use crate::arrays::patched::array::INDICES_SLOT; +use crate::arrays::patched::array::INNER_SLOT; +use crate::arrays::patched::array::LANE_OFFSETS_SLOT; use crate::arrays::patched::array::SLOT_NAMES; +use crate::arrays::patched::array::VALUES_SLOT; use crate::arrays::patched::compute::rules::PARENT_RULES; use crate::arrays::patched::vtable::kernels::PARENT_KERNELS; use crate::arrays::primitive::PrimitiveDataParts; @@ -45,6 +50,7 @@ use crate::dtype::DType; use crate::dtype::NativePType; use crate::dtype::PType; use crate::match_each_native_ptype; +use crate::require_child; use crate::serde::ArrayChildren; /// A [`Patched`]-encoded Vortex array. @@ -242,12 +248,42 @@ impl VTable for Patched { SLOT_NAMES[idx].to_string() } - fn execute(array: Array, ctx: &mut ExecutionCtx) -> VortexResult { - let inner = array - .base_array() - .clone() - .execute::(ctx)? - .into_primitive(); + fn execute(array: Array, _ctx: &mut ExecutionCtx) -> VortexResult { + let array = require_child!(array, array.base_array(), INNER_SLOT => Primitive); + let array = require_child!(array, array.lane_offsets(), LANE_OFFSETS_SLOT => Primitive); + let array = require_child!(array, array.patch_indices(), INDICES_SLOT => Primitive); + let array = require_child!(array, array.patch_values(), VALUES_SLOT => Primitive); + + let len = array.len(); + + fn take_slot(slots: &mut [Option], idx: usize) -> ArrayRef { + slots[idx].take().vortex_expect("slot must be present") + } + + fn downcast_slot(slot: ArrayRef) -> PrimitiveArray { + slot.try_downcast::() + .ok() + .vortex_expect("slot must be primitive") + } + + let (n_lanes, offset, inner, lane_offsets, indices, values) = match array.try_into_parts() { + Ok(mut parts) => { + let PatchedData { n_lanes, offset } = parts.data; + let inner = downcast_slot(take_slot(&mut parts.slots, INNER_SLOT)); + let lane_offsets = downcast_slot(take_slot(&mut parts.slots, LANE_OFFSETS_SLOT)); + let indices = downcast_slot(take_slot(&mut parts.slots, INDICES_SLOT)); + let values = downcast_slot(take_slot(&mut parts.slots, VALUES_SLOT)); + (n_lanes, offset, inner, lane_offsets, indices, values) + } + Err(array) => { + let PatchedData { n_lanes, offset } = array.data().clone(); + let inner = downcast_slot(array.base_array().clone()); + let lane_offsets = downcast_slot(array.lane_offsets().clone()); + let indices = downcast_slot(array.patch_indices().clone()); + let values = downcast_slot(array.patch_values().clone()); + (n_lanes, offset, inner, lane_offsets, indices, values) + } + }; let PrimitiveDataParts { buffer, @@ -255,32 +291,14 @@ impl VTable for Patched { validity, } = inner.into_data_parts(); - let lane_offsets = array - .lane_offsets() - .clone() - .execute::(ctx)?; - let indices = array - .patch_indices() - .clone() - .execute::(ctx)?; - - // TODO(aduffy): add support for non-primitive PatchedArray patches application (?) - let values = array - .patch_values() - .clone() - .execute::(ctx)?; - let patched_values = match_each_native_ptype!(values.ptype(), |V| { - let offset = array.offset(); - let len = array.len(); - let mut output = Buffer::::from_byte_buffer(buffer.unwrap_host()).into_mut(); apply_patches_primitive::( &mut output, offset, len, - array.n_lanes(), + n_lanes, lane_offsets.as_slice::(), indices.as_slice::(), values.as_slice::(),