@@ -448,18 +448,26 @@ impl InnerProduct {
448448 return Ok ( None ) ;
449449 }
450450
451- // The other side must be a constant tensor.
452- let Some ( const_storage) = constant_tensor_storage ( const_ref) else {
451+ // The other side must be a constant-backed tensor-like extension whose scalar is
452+ // non-null.
453+ let Some ( const_ext) = const_ref. as_opt :: < Extension > ( ) else {
453454 return Ok ( None ) ;
454455 } ;
456+ let const_storage = const_ext. storage_array ( ) ;
457+ let Some ( const_backing) = const_storage. as_opt :: < Constant > ( ) else {
458+ return Ok ( None ) ;
459+ } ;
460+ if const_backing. scalar ( ) . is_null ( ) {
461+ return Ok ( None ) ;
462+ }
455463
456464 let dim = sorf_view. options . dimension as usize ;
457465 let num_rounds = sorf_view. options . num_rounds as usize ;
458466 let seed = sorf_view. options . seed ;
459467 let padded_dim = dim. next_power_of_two ( ) ;
460468
461469 // Extract the single stored row of the constant via the stride-0 short-circuit.
462- let flat = extract_flat_elements ( & const_storage, dim, ctx) ?;
470+ let flat = extract_flat_elements ( const_storage, dim, ctx) ?;
463471 if flat. ptype ( ) != PType :: F32 {
464472 // TODO(connor): as above, f16/f64 are not supported by this rewrite yet. The
465473 // standard path handles them correctly.
@@ -474,9 +482,9 @@ impl InnerProduct {
474482 let mut rotated_query = vec ! [ 0.0f32 ; padded_dim] ;
475483 rotation. rotate ( & padded_query, & mut rotated_query) ;
476484
477- // Build the rewritten constant as a `Vector<padded_dim, f32>` extension scalar. We reuse
478- // the original storage FSL nullability so the new extension dtype stays consistent with
479- // whatever the original tree expected.
485+ // Build the rewritten constant as a `Vector<padded_dim, f32>` extension wrapping a
486+ // `ConstantArray` of length `len`. We reuse the original storage FSL nullability so
487+ // the new extension dtype stays consistent with whatever the original tree expected.
480488 let storage_fsl_nullability = const_storage. dtype ( ) . nullability ( ) ;
481489 let element_dtype = DType :: Primitive ( PType :: F32 , Nullability :: NonNullable ) ;
482490 let children: Vec < Scalar > = rotated_query
@@ -485,6 +493,7 @@ impl InnerProduct {
485493 . collect ( ) ;
486494 let fsl_scalar =
487495 Scalar :: fixed_size_list ( element_dtype. clone ( ) , children, storage_fsl_nullability) ;
496+ let new_storage = ConstantArray :: new ( fsl_scalar, len) . into_array ( ) ;
488497
489498 // Build a fresh `Vector<padded_dim, f32>` extension dtype. We cannot reuse the
490499 // original extension dtype because that one has `dim`, not `padded_dim`.
@@ -495,8 +504,7 @@ impl InnerProduct {
495504 storage_fsl_nullability,
496505 ) ;
497506 let new_ext_dtype = ExtDType :: < Vector > :: try_new ( EmptyMetadata , new_fsl_dtype) ?. erased ( ) ;
498- let new_constant =
499- ConstantArray :: new ( Scalar :: extension_ref ( new_ext_dtype, fsl_scalar) , len) . into_array ( ) ;
507+ let new_constant = ExtensionArray :: new ( new_ext_dtype, new_storage) . into_array ( ) ;
500508
501509 // Extract the SorfTransform child (the already-padded Vector<padded_dim, f32>).
502510 let sorf_child = sorf_view
@@ -564,9 +572,16 @@ impl InnerProduct {
564572 } ;
565573
566574 // Navigate the constant side and require its scalar be non-null.
567- let Some ( const_storage ) = constant_tensor_storage ( const_candidate) else {
575+ let Some ( const_ext ) = const_candidate. as_opt :: < Extension > ( ) else {
568576 return Ok ( None ) ;
569577 } ;
578+ let const_storage = const_ext. storage_array ( ) ;
579+ let Some ( const_backing) = const_storage. as_opt :: < Constant > ( ) else {
580+ return Ok ( None ) ;
581+ } ;
582+ if const_backing. scalar ( ) . is_null ( ) {
583+ return Ok ( None ) ;
584+ }
570585
571586 // Canonicalize codes and values. Codes may be e.g. BitPacked; executing is cheaper
572587 // than falling through to the standard path (which would also canonicalize).
@@ -587,7 +602,7 @@ impl InnerProduct {
587602
588603 let padded_dim = usize:: try_from ( fsl. list_size ( ) ) . vortex_expect ( "fsl list_size fits usize" ) ;
589604
590- let flat = extract_flat_elements ( & const_storage, padded_dim, ctx) ?;
605+ let flat = extract_flat_elements ( const_storage, padded_dim, ctx) ?;
591606 if flat. ptype ( ) != PType :: F32 {
592607 // TODO(connor): case 2 is f32-only. For f16/f64 we fall through to the standard
593608 // path, which computes the inner product with the correct element type.
@@ -622,16 +637,6 @@ impl InnerProduct {
622637 }
623638}
624639
625- /// Return the storage constant for a canonical tensor-like constant query.
626- fn constant_tensor_storage ( array : & ArrayRef ) -> Option < ArrayRef > {
627- let constant = array. as_opt :: < Constant > ( ) ?;
628- if constant. scalar ( ) . is_null ( ) {
629- return None ;
630- }
631- let ext_scalar = constant. scalar ( ) . as_extension_opt ( ) ?;
632- Some ( ConstantArray :: new ( ext_scalar. to_storage_scalar ( ) , array. len ( ) ) . into_array ( ) )
633- }
634-
635640/// Computes the inner product (dot product) of two equal-length float slices.
636641///
637642/// Returns `sum(a_i * b_i)`.
@@ -954,7 +959,6 @@ mod tests {
954959 use vortex_array:: ArrayRef ;
955960 use vortex_array:: IntoArray ;
956961 use vortex_array:: VortexSessionExecute ;
957- use vortex_array:: arrays:: Constant ;
958962 use vortex_array:: arrays:: ConstantArray ;
959963 use vortex_array:: arrays:: ExtensionArray ;
960964 use vortex_array:: arrays:: FixedSizeListArray ;
@@ -974,11 +978,9 @@ mod tests {
974978 use vortex_session:: VortexSession ;
975979
976980 use crate :: scalar_fns:: inner_product:: InnerProduct ;
977- use crate :: scalar_fns:: inner_product:: constant_tensor_storage;
978981 use crate :: scalar_fns:: sorf_transform:: SorfMatrix ;
979982 use crate :: scalar_fns:: sorf_transform:: SorfOptions ;
980983 use crate :: scalar_fns:: sorf_transform:: SorfTransform ;
981- use crate :: utils:: extract_flat_elements;
982984 use crate :: vector:: Vector ;
983985
984986 static SESSION : LazyLock < VortexSession > =
@@ -1009,19 +1011,6 @@ mod tests {
10091011 Ok ( ExtensionArray :: new ( ext_dtype, storage) . into_array ( ) )
10101012 }
10111013
1012- /// Expression-literal shape: a ConstantArray whose scalar itself is a Vector extension.
1013- fn literal_vector_f32 ( elements : & [ f32 ] , len : usize ) -> ArrayRef {
1014- let element_dtype = DType :: Primitive ( PType :: F32 , Nullability :: NonNullable ) ;
1015- let children: Vec < Scalar > = elements
1016- . iter ( )
1017- . map ( |& v| Scalar :: primitive ( v, Nullability :: NonNullable ) )
1018- . collect ( ) ;
1019- let storage_scalar =
1020- Scalar :: fixed_size_list ( element_dtype, children, Nullability :: NonNullable ) ;
1021- let vector_scalar = Scalar :: extension :: < Vector > ( EmptyMetadata , storage_scalar) ;
1022- ConstantArray :: new ( vector_scalar, len) . into_array ( )
1023- }
1024-
10251014 /// Build an `ExtensionArray<Vector<list_size, f32>>` whose storage is
10261015 /// `FSL(DictArray(codes: u8, values: f32))`. This mirrors the shape that
10271016 /// TurboQuant produces as the SorfTransform child.
@@ -1126,27 +1115,6 @@ mod tests {
11261115
11271116 // ---- Case 1: SorfTransform + Constant pull-through ----
11281117
1129- #[ test]
1130- fn constant_tensor_storage_accepts_extension_scalar_literal ( ) -> VortexResult < ( ) > {
1131- let literal = literal_vector_f32 ( & [ 1.0 , 2.0 , 3.0 ] , 5 ) ;
1132- let storage =
1133- constant_tensor_storage ( & literal) . expect ( "literal vector should be recognized" ) ;
1134-
1135- assert_eq ! ( storage. len( ) , 5 ) ;
1136- let const_storage = storage
1137- . as_opt :: < Constant > ( )
1138- . expect ( "storage should remain constant-backed" ) ;
1139- assert ! ( matches!(
1140- const_storage. scalar( ) . dtype( ) ,
1141- DType :: FixedSizeList ( _, 3 , Nullability :: NonNullable )
1142- ) ) ;
1143-
1144- let mut ctx = SESSION . create_execution_ctx ( ) ;
1145- let flat = extract_flat_elements ( & storage, 3 , & mut ctx) ?;
1146- assert_eq ! ( flat. row:: <f32 >( 0 ) , & [ 1.0 , 2.0 , 3.0 ] ) ;
1147- Ok ( ( ) )
1148- }
1149-
11501118 /// Case 1: SorfTransform on LHS, constant query on RHS, with `dim < padded_dim`
11511119 /// so the zero-padding branch is exercised.
11521120 #[ test]
0 commit comments