Skip to content

Commit f1d3124

Browse files
committed
clean up
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 62513c9 commit f1d3124

3 files changed

Lines changed: 44 additions & 29 deletions

File tree

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -138,22 +138,12 @@ impl ScalarFnVTable for CosineSimilarity {
138138
let lhs_storage = extension_storage(&lhs)?;
139139
let rhs_storage = extension_storage(&rhs)?;
140140

141-
// Extract the flat primitive elements from each tensor column. When an input is a
142-
// `ConstantArray` (e.g., a literal query vector), we materialize only a single row
143-
// instead of expanding it to the full row count.
144-
let (lhs_elems, lhs_stride) = extract_flat_elements(&lhs_storage, list_size)?;
145-
let (rhs_elems, rhs_stride) = extract_flat_elements(&rhs_storage, list_size)?;
146-
147-
match_each_float_ptype!(lhs_elems.ptype(), |T| {
148-
let lhs_slice = lhs_elems.as_slice::<T>();
149-
let rhs_slice = rhs_elems.as_slice::<T>();
141+
let lhs_flat = extract_flat_elements(&lhs_storage, list_size)?;
142+
let rhs_flat = extract_flat_elements(&rhs_storage, list_size)?;
150143

144+
match_each_float_ptype!(lhs_flat.ptype(), |T| {
151145
let result: PrimitiveArray = (0..row_count)
152-
.map(|i| {
153-
let a = &lhs_slice[i * lhs_stride..i * lhs_stride + list_size];
154-
let b = &rhs_slice[i * rhs_stride..i * rhs_stride + list_size];
155-
cosine_similarity_row(a, b)
156-
})
146+
.map(|i| cosine_similarity_row(lhs_flat.row::<T>(i), rhs_flat.row::<T>(i)))
157147
.collect();
158148

159149
Ok(result.into_array())

vortex-tensor/src/scalar_fns/l2_norm.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,11 @@ impl ScalarFnVTable for L2Norm {
116116
let list_size = extension_list_size(ext)?;
117117

118118
let storage = extension_storage(&input)?;
119-
let (elems, stride) = extract_flat_elements(&storage, list_size)?;
120-
121-
match_each_float_ptype!(elems.ptype(), |T| {
122-
let slice = elems.as_slice::<T>();
119+
let flat = extract_flat_elements(&storage, list_size)?;
123120

121+
match_each_float_ptype!(flat.ptype(), |T| {
124122
let result: PrimitiveArray = (0..row_count)
125-
.map(|i| {
126-
let v = &slice[i * stride..i * stride + list_size];
127-
l2_norm_row(v)
128-
})
123+
.map(|i| l2_norm_row(flat.row::<T>(i)))
129124
.collect();
130125

131126
Ok(result.into_array())

vortex-tensor/src/scalar_fns/utils.rs

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use vortex::array::arrays::ConstantArray;
88
use vortex::array::arrays::Extension;
99
use vortex::array::arrays::PrimitiveArray;
1010
use vortex::dtype::DType;
11+
use vortex::dtype::NativePType;
1112
use vortex::dtype::PType;
1213
use vortex::dtype::extension::ExtDTypeRef;
1314
use vortex::error::VortexResult;
@@ -60,28 +61,57 @@ pub(crate) fn extension_storage(array: &ArrayRef) -> VortexResult<ArrayRef> {
6061
Ok(ext.storage_array().clone())
6162
}
6263

63-
// TODO(connor): it would be nicer if this took a generic parameter and a FnMut arg that we run
64-
// directly on the values without having to return this ugly stride.
64+
/// The flat primitive elements of a tensor storage array, with typed row access.
65+
///
66+
/// This struct hides the stride detail that arises from the [`ConstantArray`] optimization: a
67+
/// constant input materializes only a single row (stride=0), while a full array uses
68+
/// stride=list_size.
69+
pub(crate) struct FlatElements {
70+
elems: PrimitiveArray,
71+
stride: usize,
72+
list_size: usize,
73+
}
74+
75+
impl FlatElements {
76+
/// Returns the [`PType`] of the underlying elements.
77+
pub fn ptype(&self) -> PType {
78+
self.elems.ptype()
79+
}
80+
81+
/// Returns the `i`-th row as a typed slice of length `list_size`.
82+
pub fn row<T: NativePType>(&self, i: usize) -> &[T] {
83+
let slice = self.elems.as_slice::<T>();
84+
&slice[i * self.stride..i * self.stride + self.list_size]
85+
}
86+
}
87+
6588
/// Extracts the flat primitive elements from a tensor storage array (FixedSizeList).
6689
///
6790
/// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is
68-
/// materialized to avoid expanding it to the full column length. Returns `(elements, stride)`
69-
/// where `stride` is `list_size` for a full array and `0` for a constant.
91+
/// materialized to avoid expanding it to the full column length.
7092
pub(crate) fn extract_flat_elements(
7193
storage: &ArrayRef,
7294
list_size: usize,
73-
) -> VortexResult<(PrimitiveArray, usize)> {
95+
) -> VortexResult<FlatElements> {
7496
if let Some(constant) = storage.as_opt::<Constant>() {
7597
// Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a huge
7698
// amount of data.
7799
let single = ConstantArray::new(constant.scalar().clone(), 1).into_array();
78100
let fsl = single.to_canonical()?.into_fixed_size_list();
79101
let elems = fsl.elements().to_canonical()?.into_primitive();
80-
return Ok((elems, 0));
102+
return Ok(FlatElements {
103+
elems,
104+
stride: 0,
105+
list_size,
106+
});
81107
}
82108

83109
// Otherwise we have to fully expand all of the data.
84110
let fsl = storage.to_canonical()?.into_fixed_size_list();
85111
let elems = fsl.elements().to_canonical()?.into_primitive();
86-
Ok((elems, list_size))
112+
Ok(FlatElements {
113+
elems,
114+
stride: list_size,
115+
list_size,
116+
})
87117
}

0 commit comments

Comments
 (0)