Skip to content

Commit 2f9a5e2

Browse files
Revert "Add extension constant pushdown rule and fix InnerProduct rule (#7507)"
This reverts commit 869b2d1.
1 parent 4d73f97 commit 2f9a5e2

2 files changed

Lines changed: 25 additions & 118 deletions

File tree

vortex-array/src/arrays/extension/compute/rules.rs

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -6,60 +6,25 @@ use vortex_error::VortexResult;
66
use crate::ArrayRef;
77
use crate::IntoArray;
88
use crate::array::ArrayView;
9-
use crate::arrays::Constant;
10-
use crate::arrays::ConstantArray;
119
use crate::arrays::Extension;
1210
use crate::arrays::ExtensionArray;
1311
use crate::arrays::Filter;
1412
use crate::arrays::extension::ExtensionArrayExt;
1513
use crate::arrays::filter::FilterReduceAdaptor;
1614
use crate::arrays::slice::SliceReduceAdaptor;
17-
use crate::matcher::AnyArray;
1815
use crate::optimizer::rules::ArrayParentReduceRule;
1916
use crate::optimizer::rules::ParentRuleSet;
20-
use crate::scalar::Scalar;
2117
use crate::scalar_fn::fns::cast::CastReduceAdaptor;
2218
use crate::scalar_fn::fns::mask::MaskReduceAdaptor;
2319

2420
pub(crate) const PARENT_RULES: ParentRuleSet<Extension> = ParentRuleSet::new(&[
25-
ParentRuleSet::lift(&ExtensionConstantParentRule),
2621
ParentRuleSet::lift(&ExtensionFilterPushDownRule),
2722
ParentRuleSet::lift(&CastReduceAdaptor(Extension)),
2823
ParentRuleSet::lift(&FilterReduceAdaptor(Extension)),
2924
ParentRuleSet::lift(&MaskReduceAdaptor(Extension)),
3025
ParentRuleSet::lift(&SliceReduceAdaptor(Extension)),
3126
]);
3227

33-
/// Normalize `Extension(Constant(storage))` children to `Constant(Extension(storage))`.
34-
#[derive(Debug)]
35-
struct ExtensionConstantParentRule;
36-
37-
impl ArrayParentReduceRule<Extension> for ExtensionConstantParentRule {
38-
type Parent = AnyArray;
39-
40-
fn reduce_parent(
41-
&self,
42-
child: ArrayView<'_, Extension>,
43-
parent: &ArrayRef,
44-
child_idx: usize,
45-
) -> VortexResult<Option<ArrayRef>> {
46-
let Some(const_array) = child.storage_array().as_opt::<Constant>() else {
47-
return Ok(None);
48-
};
49-
50-
let storage_scalar = const_array.scalar().clone();
51-
let ext_scalar = Scalar::extension_ref(child.ext_dtype().clone(), storage_scalar);
52-
53-
let constant_with_extension_scalar =
54-
ConstantArray::new(ext_scalar, child.len()).into_array();
55-
56-
parent
57-
.clone()
58-
.with_slot(child_idx, constant_with_extension_scalar)
59-
.map(Some)
60-
}
61-
}
62-
6328
/// Push filter operations into the storage array of an extension array.
6429
#[derive(Debug)]
6530
struct ExtensionFilterPushDownRule;
@@ -93,7 +58,6 @@ mod tests {
9358
use crate::IntoArray;
9459
#[expect(deprecated)]
9560
use crate::ToCanonical as _;
96-
use crate::arrays::Constant;
9761
use crate::arrays::ConstantArray;
9862
use crate::arrays::Extension;
9963
use crate::arrays::ExtensionArray;
@@ -213,31 +177,6 @@ mod tests {
213177
assert_eq!(canonical.len(), 3);
214178
}
215179

216-
#[test]
217-
fn test_extension_constant_child_normalizes_under_scalar_fn() {
218-
let ext_dtype = test_ext_dtype();
219-
220-
let constant_storage = ConstantArray::new(Scalar::from(10i64), 3).into_array();
221-
let constant_ext = ExtensionArray::new(ext_dtype.clone(), constant_storage).into_array();
222-
223-
let storage = buffer![15i64, 25, 35].into_array();
224-
let ext_array = ExtensionArray::new(ext_dtype, storage).into_array();
225-
226-
let scalar_fn_array = Binary
227-
.try_new_array(3, Operator::Lt, [constant_ext, ext_array])
228-
.unwrap();
229-
230-
let optimized = scalar_fn_array.optimize().unwrap();
231-
let scalar_fn = optimized.as_opt::<crate::arrays::ScalarFnVTable>().unwrap();
232-
let children = scalar_fn.children();
233-
let constant = children[0]
234-
.as_opt::<Constant>()
235-
.expect("constant extension child should be normalized");
236-
237-
assert!(constant.scalar().as_extension_opt().is_some());
238-
assert_eq!(constant.len(), 3);
239-
}
240-
241180
#[test]
242181
fn test_scalar_fn_no_pushdown_different_ext_types() {
243182
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]

vortex-tensor/src/scalar_fns/inner_product.rs

Lines changed: 25 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -448,18 +448,26 @@ impl InnerProduct {
448448
return Ok(None);
449449
}
450450

451-
// The other side must be a constant tensor.
452-
let Some(const_storage) = constant_tensor_storage(const_ref) else {
451+
// The other side must be a constant-backed tensor-like extension whose scalar is
452+
// non-null.
453+
let Some(const_ext) = const_ref.as_opt::<Extension>() else {
453454
return Ok(None);
454455
};
456+
let const_storage = const_ext.storage_array();
457+
let Some(const_backing) = const_storage.as_opt::<Constant>() else {
458+
return Ok(None);
459+
};
460+
if const_backing.scalar().is_null() {
461+
return Ok(None);
462+
}
455463

456464
let dim = sorf_view.options.dimension as usize;
457465
let num_rounds = sorf_view.options.num_rounds as usize;
458466
let seed = sorf_view.options.seed;
459467
let padded_dim = dim.next_power_of_two();
460468

461469
// Extract the single stored row of the constant via the stride-0 short-circuit.
462-
let flat = extract_flat_elements(&const_storage, dim, ctx)?;
470+
let flat = extract_flat_elements(const_storage, dim, ctx)?;
463471
if flat.ptype() != PType::F32 {
464472
// TODO(connor): as above, f16/f64 are not supported by this rewrite yet. The
465473
// standard path handles them correctly.
@@ -474,9 +482,9 @@ impl InnerProduct {
474482
let mut rotated_query = vec![0.0f32; padded_dim];
475483
rotation.rotate(&padded_query, &mut rotated_query);
476484

477-
// Build the rewritten constant as a `Vector<padded_dim, f32>` extension scalar. We reuse
478-
// the original storage FSL nullability so the new extension dtype stays consistent with
479-
// whatever the original tree expected.
485+
// Build the rewritten constant as a `Vector<padded_dim, f32>` extension wrapping a
486+
// `ConstantArray` of length `len`. We reuse the original storage FSL nullability so
487+
// the new extension dtype stays consistent with whatever the original tree expected.
480488
let storage_fsl_nullability = const_storage.dtype().nullability();
481489
let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
482490
let children: Vec<Scalar> = rotated_query
@@ -485,6 +493,7 @@ impl InnerProduct {
485493
.collect();
486494
let fsl_scalar =
487495
Scalar::fixed_size_list(element_dtype.clone(), children, storage_fsl_nullability);
496+
let new_storage = ConstantArray::new(fsl_scalar, len).into_array();
488497

489498
// Build a fresh `Vector<padded_dim, f32>` extension dtype. We cannot reuse the
490499
// original extension dtype because that one has `dim`, not `padded_dim`.
@@ -495,8 +504,7 @@ impl InnerProduct {
495504
storage_fsl_nullability,
496505
);
497506
let new_ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, new_fsl_dtype)?.erased();
498-
let new_constant =
499-
ConstantArray::new(Scalar::extension_ref(new_ext_dtype, fsl_scalar), len).into_array();
507+
let new_constant = ExtensionArray::new(new_ext_dtype, new_storage).into_array();
500508

501509
// Extract the SorfTransform child (the already-padded Vector<padded_dim, f32>).
502510
let sorf_child = sorf_view
@@ -564,9 +572,16 @@ impl InnerProduct {
564572
};
565573

566574
// Navigate the constant side and require its scalar be non-null.
567-
let Some(const_storage) = constant_tensor_storage(const_candidate) else {
575+
let Some(const_ext) = const_candidate.as_opt::<Extension>() else {
568576
return Ok(None);
569577
};
578+
let const_storage = const_ext.storage_array();
579+
let Some(const_backing) = const_storage.as_opt::<Constant>() else {
580+
return Ok(None);
581+
};
582+
if const_backing.scalar().is_null() {
583+
return Ok(None);
584+
}
570585

571586
// Canonicalize codes and values. Codes may be e.g. BitPacked; executing is cheaper
572587
// than falling through to the standard path (which would also canonicalize).
@@ -587,7 +602,7 @@ impl InnerProduct {
587602

588603
let padded_dim = usize::try_from(fsl.list_size()).vortex_expect("fsl list_size fits usize");
589604

590-
let flat = extract_flat_elements(&const_storage, padded_dim, ctx)?;
605+
let flat = extract_flat_elements(const_storage, padded_dim, ctx)?;
591606
if flat.ptype() != PType::F32 {
592607
// TODO(connor): case 2 is f32-only. For f16/f64 we fall through to the standard
593608
// path, which computes the inner product with the correct element type.
@@ -622,16 +637,6 @@ impl InnerProduct {
622637
}
623638
}
624639

625-
/// Return the storage constant for a canonical tensor-like constant query.
626-
fn constant_tensor_storage(array: &ArrayRef) -> Option<ArrayRef> {
627-
let constant = array.as_opt::<Constant>()?;
628-
if constant.scalar().is_null() {
629-
return None;
630-
}
631-
let ext_scalar = constant.scalar().as_extension_opt()?;
632-
Some(ConstantArray::new(ext_scalar.to_storage_scalar(), array.len()).into_array())
633-
}
634-
635640
/// Computes the inner product (dot product) of two equal-length float slices.
636641
///
637642
/// Returns `sum(a_i * b_i)`.
@@ -954,7 +959,6 @@ mod tests {
954959
use vortex_array::ArrayRef;
955960
use vortex_array::IntoArray;
956961
use vortex_array::VortexSessionExecute;
957-
use vortex_array::arrays::Constant;
958962
use vortex_array::arrays::ConstantArray;
959963
use vortex_array::arrays::ExtensionArray;
960964
use vortex_array::arrays::FixedSizeListArray;
@@ -974,11 +978,9 @@ mod tests {
974978
use vortex_session::VortexSession;
975979

976980
use crate::scalar_fns::inner_product::InnerProduct;
977-
use crate::scalar_fns::inner_product::constant_tensor_storage;
978981
use crate::scalar_fns::sorf_transform::SorfMatrix;
979982
use crate::scalar_fns::sorf_transform::SorfOptions;
980983
use crate::scalar_fns::sorf_transform::SorfTransform;
981-
use crate::utils::extract_flat_elements;
982984
use crate::vector::Vector;
983985

984986
static SESSION: LazyLock<VortexSession> =
@@ -1009,19 +1011,6 @@ mod tests {
10091011
Ok(ExtensionArray::new(ext_dtype, storage).into_array())
10101012
}
10111013

1012-
/// Expression-literal shape: a ConstantArray whose scalar itself is a Vector extension.
1013-
fn literal_vector_f32(elements: &[f32], len: usize) -> ArrayRef {
1014-
let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
1015-
let children: Vec<Scalar> = elements
1016-
.iter()
1017-
.map(|&v| Scalar::primitive(v, Nullability::NonNullable))
1018-
.collect();
1019-
let storage_scalar =
1020-
Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
1021-
let vector_scalar = Scalar::extension::<Vector>(EmptyMetadata, storage_scalar);
1022-
ConstantArray::new(vector_scalar, len).into_array()
1023-
}
1024-
10251014
/// Build an `ExtensionArray<Vector<list_size, f32>>` whose storage is
10261015
/// `FSL(DictArray(codes: u8, values: f32))`. This mirrors the shape that
10271016
/// TurboQuant produces as the SorfTransform child.
@@ -1126,27 +1115,6 @@ mod tests {
11261115

11271116
// ---- Case 1: SorfTransform + Constant pull-through ----
11281117

1129-
#[test]
1130-
fn constant_tensor_storage_accepts_extension_scalar_literal() -> VortexResult<()> {
1131-
let literal = literal_vector_f32(&[1.0, 2.0, 3.0], 5);
1132-
let storage =
1133-
constant_tensor_storage(&literal).expect("literal vector should be recognized");
1134-
1135-
assert_eq!(storage.len(), 5);
1136-
let const_storage = storage
1137-
.as_opt::<Constant>()
1138-
.expect("storage should remain constant-backed");
1139-
assert!(matches!(
1140-
const_storage.scalar().dtype(),
1141-
DType::FixedSizeList(_, 3, Nullability::NonNullable)
1142-
));
1143-
1144-
let mut ctx = SESSION.create_execution_ctx();
1145-
let flat = extract_flat_elements(&storage, 3, &mut ctx)?;
1146-
assert_eq!(flat.row::<f32>(0), &[1.0, 2.0, 3.0]);
1147-
Ok(())
1148-
}
1149-
11501118
/// Case 1: SorfTransform on LHS, constant query on RHS, with `dim < padded_dim`
11511119
/// so the zero-padding branch is exercised.
11521120
#[test]

0 commit comments

Comments
 (0)