diff --git a/vortex-array/src/arrays/extension/compute/rules.rs b/vortex-array/src/arrays/extension/compute/rules.rs index 6a58e4838be..7408488a0f1 100644 --- a/vortex-array/src/arrays/extension/compute/rules.rs +++ b/vortex-array/src/arrays/extension/compute/rules.rs @@ -6,18 +6,23 @@ use vortex_error::VortexResult; use crate::ArrayRef; use crate::IntoArray; use crate::array::ArrayView; +use crate::arrays::Constant; +use crate::arrays::ConstantArray; use crate::arrays::Extension; use crate::arrays::ExtensionArray; use crate::arrays::Filter; use crate::arrays::extension::ExtensionArrayExt; use crate::arrays::filter::FilterReduceAdaptor; use crate::arrays::slice::SliceReduceAdaptor; +use crate::matcher::AnyArray; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; +use crate::scalar::Scalar; use crate::scalar_fn::fns::cast::CastReduceAdaptor; use crate::scalar_fn::fns::mask::MaskReduceAdaptor; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ + ParentRuleSet::lift(&ExtensionConstantParentRule), ParentRuleSet::lift(&ExtensionFilterPushDownRule), ParentRuleSet::lift(&CastReduceAdaptor(Extension)), ParentRuleSet::lift(&FilterReduceAdaptor(Extension)), @@ -25,6 +30,36 @@ pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&SliceReduceAdaptor(Extension)), ]); +/// Normalize `Extension(Constant(storage))` children to `Constant(Extension(storage))`. +#[derive(Debug)] +struct ExtensionConstantParentRule; + +impl ArrayParentReduceRule for ExtensionConstantParentRule { + type Parent = AnyArray; + + fn reduce_parent( + &self, + child: ArrayView<'_, Extension>, + parent: &ArrayRef, + child_idx: usize, + ) -> VortexResult> { + let Some(const_array) = child.storage_array().as_opt::() else { + return Ok(None); + }; + + let storage_scalar = const_array.scalar().clone(); + let ext_scalar = Scalar::extension_ref(child.ext_dtype().clone(), storage_scalar); + + let constant_with_extension_scalar = + ConstantArray::new(ext_scalar, child.len()).into_array(); + + parent + .clone() + .with_slot(child_idx, constant_with_extension_scalar) + .map(Some) + } +} + /// Push filter operations into the storage array of an extension array. #[derive(Debug)] struct ExtensionFilterPushDownRule; @@ -58,6 +93,7 @@ mod tests { use crate::IntoArray; #[expect(deprecated)] use crate::ToCanonical as _; + use crate::arrays::Constant; use crate::arrays::ConstantArray; use crate::arrays::Extension; use crate::arrays::ExtensionArray; @@ -177,6 +213,31 @@ mod tests { assert_eq!(canonical.len(), 3); } + #[test] + fn test_extension_constant_child_normalizes_under_scalar_fn() { + let ext_dtype = test_ext_dtype(); + + let constant_storage = ConstantArray::new(Scalar::from(10i64), 3).into_array(); + let constant_ext = ExtensionArray::new(ext_dtype.clone(), constant_storage).into_array(); + + let storage = buffer![15i64, 25, 35].into_array(); + let ext_array = ExtensionArray::new(ext_dtype, storage).into_array(); + + let scalar_fn_array = Binary + .try_new_array(3, Operator::Lt, [constant_ext, ext_array]) + .unwrap(); + + let optimized = scalar_fn_array.optimize().unwrap(); + let scalar_fn = optimized.as_opt::().unwrap(); + let children = scalar_fn.children(); + let constant = children[0] + .as_opt::() + .expect("constant extension child should be normalized"); + + assert!(constant.scalar().as_extension_opt().is_some()); + assert_eq!(constant.len(), 3); + } + #[test] fn test_scalar_fn_no_pushdown_different_ext_types() { #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index 5928335ccf8..dd9c2a7381f 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -448,18 +448,10 @@ impl InnerProduct { return Ok(None); } - // The other side must be a constant-backed tensor-like extension whose scalar is - // non-null. - let Some(const_ext) = const_ref.as_opt::() else { + // The other side must be a constant tensor. + let Some(const_storage) = constant_tensor_storage(const_ref) else { return Ok(None); }; - let const_storage = const_ext.storage_array(); - let Some(const_backing) = const_storage.as_opt::() else { - return Ok(None); - }; - if const_backing.scalar().is_null() { - return Ok(None); - } let dim = sorf_view.options.dimension as usize; let num_rounds = sorf_view.options.num_rounds as usize; @@ -467,7 +459,7 @@ impl InnerProduct { let padded_dim = dim.next_power_of_two(); // Extract the single stored row of the constant via the stride-0 short-circuit. - let flat = extract_flat_elements(const_storage, dim, ctx)?; + let flat = extract_flat_elements(&const_storage, dim, ctx)?; if flat.ptype() != PType::F32 { // TODO(connor): as above, f16/f64 are not supported by this rewrite yet. The // standard path handles them correctly. @@ -482,9 +474,9 @@ impl InnerProduct { let mut rotated_query = vec![0.0f32; padded_dim]; rotation.rotate(&padded_query, &mut rotated_query); - // Build the rewritten constant as a `Vector` extension wrapping a - // `ConstantArray` of length `len`. We reuse the original storage FSL nullability so - // the new extension dtype stays consistent with whatever the original tree expected. + // Build the rewritten constant as a `Vector` extension scalar. We reuse + // the original storage FSL nullability so the new extension dtype stays consistent with + // whatever the original tree expected. let storage_fsl_nullability = const_storage.dtype().nullability(); let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); let children: Vec = rotated_query @@ -493,7 +485,6 @@ impl InnerProduct { .collect(); let fsl_scalar = Scalar::fixed_size_list(element_dtype.clone(), children, storage_fsl_nullability); - let new_storage = ConstantArray::new(fsl_scalar, len).into_array(); // Build a fresh `Vector` extension dtype. We cannot reuse the // original extension dtype because that one has `dim`, not `padded_dim`. @@ -504,7 +495,8 @@ impl InnerProduct { storage_fsl_nullability, ); let new_ext_dtype = ExtDType::::try_new(EmptyMetadata, new_fsl_dtype)?.erased(); - let new_constant = ExtensionArray::new(new_ext_dtype, new_storage).into_array(); + let new_constant = + ConstantArray::new(Scalar::extension_ref(new_ext_dtype, fsl_scalar), len).into_array(); // Extract the SorfTransform child (the already-padded Vector). let sorf_child = sorf_view @@ -572,16 +564,9 @@ impl InnerProduct { }; // Navigate the constant side and require its scalar be non-null. - let Some(const_ext) = const_candidate.as_opt::() else { + let Some(const_storage) = constant_tensor_storage(const_candidate) else { return Ok(None); }; - let const_storage = const_ext.storage_array(); - let Some(const_backing) = const_storage.as_opt::() else { - return Ok(None); - }; - if const_backing.scalar().is_null() { - return Ok(None); - } // Canonicalize codes and values. Codes may be e.g. BitPacked; executing is cheaper // than falling through to the standard path (which would also canonicalize). @@ -602,7 +587,7 @@ impl InnerProduct { let padded_dim = usize::try_from(fsl.list_size()).vortex_expect("fsl list_size fits usize"); - let flat = extract_flat_elements(const_storage, padded_dim, ctx)?; + let flat = extract_flat_elements(&const_storage, padded_dim, ctx)?; if flat.ptype() != PType::F32 { // TODO(connor): case 2 is f32-only. For f16/f64 we fall through to the standard // path, which computes the inner product with the correct element type. @@ -637,6 +622,16 @@ impl InnerProduct { } } +/// Return the storage constant for a canonical tensor-like constant query. +fn constant_tensor_storage(array: &ArrayRef) -> Option { + let constant = array.as_opt::()?; + if constant.scalar().is_null() { + return None; + } + let ext_scalar = constant.scalar().as_extension_opt()?; + Some(ConstantArray::new(ext_scalar.to_storage_scalar(), array.len()).into_array()) +} + /// Computes the inner product (dot product) of two equal-length float slices. /// /// Returns `sum(a_i * b_i)`. @@ -959,6 +954,7 @@ mod tests { use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; + use vortex_array::arrays::Constant; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; @@ -978,9 +974,11 @@ mod tests { use vortex_session::VortexSession; use crate::scalar_fns::inner_product::InnerProduct; + use crate::scalar_fns::inner_product::constant_tensor_storage; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfOptions; use crate::scalar_fns::sorf_transform::SorfTransform; + use crate::utils::extract_flat_elements; use crate::vector::Vector; static SESSION: LazyLock = @@ -1011,6 +1009,19 @@ mod tests { Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } + /// Expression-literal shape: a ConstantArray whose scalar itself is a Vector extension. + fn literal_vector_f32(elements: &[f32], len: usize) -> ArrayRef { + let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); + let children: Vec = elements + .iter() + .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) + .collect(); + let storage_scalar = + Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); + let vector_scalar = Scalar::extension::(EmptyMetadata, storage_scalar); + ConstantArray::new(vector_scalar, len).into_array() + } + /// Build an `ExtensionArray>` whose storage is /// `FSL(DictArray(codes: u8, values: f32))`. This mirrors the shape that /// TurboQuant produces as the SorfTransform child. @@ -1115,6 +1126,27 @@ mod tests { // ---- Case 1: SorfTransform + Constant pull-through ---- + #[test] + fn constant_tensor_storage_accepts_extension_scalar_literal() -> VortexResult<()> { + let literal = literal_vector_f32(&[1.0, 2.0, 3.0], 5); + let storage = + constant_tensor_storage(&literal).expect("literal vector should be recognized"); + + assert_eq!(storage.len(), 5); + let const_storage = storage + .as_opt::() + .expect("storage should remain constant-backed"); + assert!(matches!( + const_storage.scalar().dtype(), + DType::FixedSizeList(_, 3, Nullability::NonNullable) + )); + + let mut ctx = SESSION.create_execution_ctx(); + let flat = extract_flat_elements(&storage, 3, &mut ctx)?; + assert_eq!(flat.row::(0), &[1.0, 2.0, 3.0]); + Ok(()) + } + /// Case 1: SorfTransform on LHS, constant query on RHS, with `dim < padded_dim` /// so the zero-padding branch is exercised. #[test]