Skip to content

Commit 137a6c3

Browse files
committed
add extract constant flat row
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent a0c6252 commit 137a6c3

4 files changed

Lines changed: 78 additions & 27 deletions

File tree

vortex-tensor/public-api.lock

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ impl vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>
142142

143143
pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::element_ptype(&self) -> vortex_array::dtype::ptype::PType
144144

145-
pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::list_size(&self) -> usize
145+
pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::flat_list_size(&self) -> u32
146146

147147
pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::metadata(&self) -> &vortex_tensor::fixed_shape::FixedShapeTensorMetadata
148148

@@ -222,7 +222,7 @@ impl vortex_tensor::matcher::TensorMatch<'_>
222222

223223
pub fn vortex_tensor::matcher::TensorMatch<'_>::element_ptype(self) -> vortex_array::dtype::ptype::PType
224224

225-
pub fn vortex_tensor::matcher::TensorMatch<'_>::list_size(self) -> usize
225+
pub fn vortex_tensor::matcher::TensorMatch<'_>::list_size(self) -> u32
226226

227227
impl<'a> core::clone::Clone for vortex_tensor::matcher::TensorMatch<'a>
228228

@@ -382,7 +382,7 @@ pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::validity(&self, _options:
382382

383383
pub fn vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm(input: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::vtable::ScalarFnArray>
384384

385-
pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows(input: &vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()>
385+
pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows_against_norms(normalized: &vortex_array::array::erased::ArrayRef, norms: core::option::Option<&vortex_array::array::erased::ArrayRef>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()>
386386

387387
pub mod vortex_tensor::scalar_fns::l2_norm
388388

@@ -570,12 +570,10 @@ pub struct vortex_tensor::vector::VectorMatcherMetadata
570570

571571
impl vortex_tensor::vector::VectorMatcherMetadata
572572

573-
pub fn vortex_tensor::vector::VectorMatcherMetadata::dimensions(&self) -> usize
573+
pub fn vortex_tensor::vector::VectorMatcherMetadata::dimensions(&self) -> u32
574574

575575
pub fn vortex_tensor::vector::VectorMatcherMetadata::element_ptype(&self) -> vortex_array::dtype::ptype::PType
576576

577-
pub fn vortex_tensor::vector::VectorMatcherMetadata::list_size(&self) -> u32
578-
579577
pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32) -> vortex_error::VortexResult<Self>
580578

581579
impl core::clone::Clone for vortex_tensor::vector::VectorMatcherMetadata

vortex-tensor/src/scalar_fns/inner_product.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ use crate::matcher::AnyTensor;
5555
use crate::scalar_fns::l2_denorm::DenormOrientation;
5656
use crate::scalar_fns::sorf_transform::SorfMatrix;
5757
use crate::scalar_fns::sorf_transform::SorfTransform;
58+
use crate::utils::extract_constant_flat_row;
5859
use crate::utils::extract_flat_elements;
5960
use crate::utils::extract_l2_denorm_children;
6061
use crate::utils::validate_binary_tensor_float_inputs;
@@ -425,15 +426,15 @@ impl InnerProduct {
425426
let seed = sorf_view.options.seed;
426427
let padded_dim = dim.next_power_of_two();
427428

428-
// Extract the single stored row of the constant via the `is_constant` short-circuit.
429-
let flat = extract_flat_elements(&const_storage, dim, ctx)?;
429+
// Extract the single stored row of the constant.
430+
let flat = extract_constant_flat_row(&const_storage, ctx)?;
430431
if flat.ptype() != PType::F32 {
431432
return Ok(None);
432433
}
433434

434435
// Zero-pad the query from `dim` to `padded_dim` and forward-rotate.
435436
let mut padded_query = vec![0.0f32; padded_dim];
436-
padded_query[..dim].copy_from_slice(flat.row::<f32>(0));
437+
padded_query[..dim].copy_from_slice(flat.as_slice::<f32>());
437438

438439
let rotation = SorfMatrix::try_new(seed, dim, num_rounds)?;
439440
let mut rotated_query = vec![0.0f32; padded_dim];
@@ -533,7 +534,7 @@ impl InnerProduct {
533534

534535
let padded_dim = usize::try_from(fsl.list_size()).vortex_expect("fsl list_size fits usize");
535536

536-
let flat = extract_flat_elements(&const_storage, padded_dim, ctx)?;
537+
let flat = extract_constant_flat_row(&const_storage, ctx)?;
537538
if flat.ptype() != PType::F32 {
538539
// TODO(connor): case 2 is f32-only. For f16/f64 we fall through to the standard
539540
// path, which computes the inner product with the correct element type.
@@ -552,7 +553,7 @@ impl InnerProduct {
552553
return Ok(Some(empty.into_array()));
553554
}
554555

555-
let q: &[f32] = flat.row::<f32>(0);
556+
let q: &[f32] = flat.as_slice::<f32>();
556557
debug_assert_eq!(q.len(), padded_dim);
557558
let codes: &[u8] = codes_prim.as_slice::<u8>();
558559
let values: &[f32] = values_prim.as_slice::<f32>();

vortex-tensor/src/scalar_fns/l2_denorm.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ use vortex_session::VortexSession;
6060

6161
use crate::matcher::AnyTensor;
6262
use crate::scalar_fns::l2_norm::L2Norm;
63+
use crate::utils::extract_constant_flat_row;
6364
use crate::utils::extract_flat_elements;
6465
use crate::utils::validate_tensor_float_input;
6566

@@ -525,12 +526,12 @@ pub(crate) fn try_build_constant_l2_denorm(
525526
let ext_dtype = input.dtype().as_extension().clone();
526527
let storage_fsl_nullability = storage.dtype().nullability();
527528

528-
// `extract_flat_elements` takes the `is_constant` single-row path for `Constant` storage, so
529-
// this is cheap and does not expand the constant to the full column length.
530-
let flat = extract_flat_elements(storage, list_size, ctx)?;
529+
// Materialize just the single stored row; this does not expand the constant to the full
530+
// column length.
531+
let flat = extract_constant_flat_row(storage, ctx)?;
531532

532533
let (normalized_fsl_scalar, norms_scalar) = match_each_float_ptype!(flat.ptype(), |T| {
533-
let row = flat.row::<T>(0);
534+
let row = flat.as_slice::<T>();
534535

535536
let mut sum_sq = T::zero();
536537
for &x in row {
@@ -605,19 +606,14 @@ fn unit_norm_tolerance(element_ptype: PType) -> f64 {
605606
}
606607
}
607608

608-
/// Validates that every valid row of `input` is already L2-normalized (either length 1 or 0).
609-
pub fn validate_l2_normalized_rows(input: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
610-
validate_l2_normalized_rows_against_norms(input, None, ctx)
611-
}
612-
613609
/// Validates that `normalized` and (when supplied) the matching `norms` jointly satisfy the
614610
/// [`L2Denorm`] invariants:
615611
///
616612
/// - Every valid row of `normalized` has L2 norm `1.0` or `0.0` (within element-precision
617613
/// tolerance).
618614
/// - When `norms` is supplied, every stored norm is non-negative and any row whose stored norm is
619615
/// `0.0` is exactly the zero vector in `normalized`.
620-
fn validate_l2_normalized_rows_against_norms(
616+
pub fn validate_l2_normalized_rows_against_norms(
621617
normalized: &ArrayRef,
622618
norms: Option<&ArrayRef>,
623619
ctx: &mut ExecutionCtx,
@@ -771,7 +767,7 @@ mod tests {
771767

772768
use crate::scalar_fns::l2_denorm::L2Denorm;
773769
use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
774-
use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows;
770+
use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows_against_norms;
775771
use crate::tests::SESSION;
776772
use crate::utils::test_helpers::assert_close;
777773
use crate::utils::test_helpers::constant_tensor_array;
@@ -906,15 +902,15 @@ mod tests {
906902
let input = vector_array(2, &[3.0f32, 4.0, 0.0, 0.0].map(half::f16::from_f32))?;
907903
let mut ctx = SESSION.create_execution_ctx();
908904
let roundtrip = normalize_as_l2_denorm(input, &mut ctx)?;
909-
validate_l2_normalized_rows(&roundtrip.child_at(0).clone(), &mut ctx)?;
905+
validate_l2_normalized_rows_against_norms(&roundtrip.child_at(0).clone(), None, &mut ctx)?;
910906
Ok(())
911907
}
912908

913909
#[test]
914910
fn validate_l2_normalized_rows_rejects_unnormalized_input() -> VortexResult<()> {
915911
let input = vector_array(2, &[3.0, 4.0, 1.0, 0.0])?;
916912
let mut ctx = SESSION.create_execution_ctx();
917-
let result = validate_l2_normalized_rows(&input, &mut ctx);
913+
let result = validate_l2_normalized_rows_against_norms(&input, None, &mut ctx);
918914
assert!(result.is_err());
919915
Ok(())
920916
}

vortex-tensor/src/utils.rs

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use vortex_array::arrays::ConstantArray;
99
use vortex_array::arrays::FixedSizeListArray;
1010
use vortex_array::arrays::PrimitiveArray;
1111
use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
12+
use vortex_array::arrays::primitive::PrimitiveArrayExt;
1213
use vortex_array::arrays::scalar_fn::ExactScalarFn;
1314
use vortex_array::dtype::DType;
1415
use vortex_array::dtype::NativePType;
@@ -124,12 +125,11 @@ impl FlatElements {
124125
}
125126
}
126127

127-
// TODO(connor): Usage of this function is sometimes incorrect / not performant.
128-
// Make sure to fix them.
129128
/// Extracts the flat primitive elements from a tensor storage array (FixedSizeList).
130129
///
131130
/// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is
132-
/// materialized to avoid expanding it to the full column length.
131+
/// materialized to avoid expanding it to the full column length. Callers that have already
132+
/// confirmed the storage is constant-backed should prefer [`extract_constant_flat_row`].
133133
pub fn extract_flat_elements(
134134
storage: &ArrayRef,
135135
list_size: usize,
@@ -146,13 +146,69 @@ pub fn extract_flat_elements(
146146

147147
let fsl: FixedSizeListArray = source.execute(ctx)?;
148148
let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?;
149+
vortex_ensure!(
150+
!elems.nullability().is_nullable(),
151+
"tensor storage elements must be non-nullable, got {}",
152+
elems.dtype(),
153+
);
149154
Ok(FlatElements {
150155
elems,
151156
list_size,
152157
is_constant,
153158
})
154159
}
155160

161+
/// The single stored row of a constant-backed tensor storage array.
162+
///
163+
/// Contrast with [`FlatElements`], which exposes arbitrary row indices: a `FlatRow` statically
164+
/// encodes "there is exactly one row available," so call sites that have gated on a constant input
165+
/// read the row via [`Self::as_slice`] instead of `row(0)`.
166+
pub struct FlatRow {
167+
elems: PrimitiveArray,
168+
}
169+
170+
impl FlatRow {
171+
/// Returns the [`PType`] of the underlying elements.
172+
#[must_use]
173+
pub fn ptype(&self) -> PType {
174+
self.elems.ptype()
175+
}
176+
177+
/// Returns the stored row as a typed slice. Its length equals the storage scalar's
178+
/// fixed-size-list size.
179+
#[must_use]
180+
pub fn as_slice<T: NativePType>(&self) -> &[T] {
181+
self.elems.as_slice::<T>()
182+
}
183+
}
184+
185+
/// Extracts the single stored row from a [`Constant`]-backed tensor storage array.
186+
///
187+
/// The caller must have confirmed that `storage` is a [`Constant`] encoding whose scalar is a
188+
/// non-null fixed-size list. This is the fast path for constant query vectors: exactly one row is
189+
/// materialized regardless of the column length. Returns an error if the elements are nullable.
190+
///
191+
/// # Panics
192+
///
193+
/// Panics if `storage` is not a [`Constant`] encoding.
194+
pub fn extract_constant_flat_row(
195+
storage: &ArrayRef,
196+
ctx: &mut ExecutionCtx,
197+
) -> VortexResult<FlatRow> {
198+
let constant = storage
199+
.as_opt::<Constant>()
200+
.vortex_expect("extract_constant_flat_row requires Constant-backed storage");
201+
let single = ConstantArray::new(constant.scalar().clone(), 1).into_array();
202+
let fsl: FixedSizeListArray = single.execute(ctx)?;
203+
let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?;
204+
vortex_ensure!(
205+
!elems.nullability().is_nullable(),
206+
"tensor storage elements must be non-nullable, got {}",
207+
elems.dtype(),
208+
);
209+
Ok(FlatRow { elems })
210+
}
211+
156212
/// Extracts the `(normalized, norms)` children from an [`L2Denorm`] scalar function array.
157213
///
158214
/// [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm

0 commit comments

Comments
 (0)