Skip to content

Commit 810d306

Browse files
committed
clean up
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 73e9e34 commit 810d306

5 files changed

Lines changed: 173 additions & 241 deletions

File tree

vortex-array/src/scalar_fn/vtable.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
9292
Ok(args.to_vec())
9393
}
9494

95+
// TODO(connor): This needs a precondition for the number of args it has, or all implementations
96+
// need to return an error if it is wrong.
9597
/// Compute the return [`DType`] of the expression if evaluated over the given input types.
9698
///
9799
/// # Preconditions

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 29 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use vortex::dtype::Nullability;
1919
use vortex::dtype::extension::Matcher;
2020
use vortex::error::VortexResult;
2121
use vortex::error::vortex_ensure;
22+
use vortex::error::vortex_ensure_eq;
2223
use vortex::error::vortex_err;
2324
use vortex::expr::Expression;
2425
use vortex::scalar_fn::Arity;
@@ -81,33 +82,38 @@ impl ScalarFnVTable for CosineSimilarity {
8182
}
8283

8384
fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
84-
debug_assert_eq!(arg_dtypes.len(), 2);
85+
vortex_ensure_eq!(
86+
arg_dtypes.len(),
87+
2,
88+
"CosineSimilarity requires exactly 2 arguments, got {}",
89+
arg_dtypes.len()
90+
);
8591

8692
let lhs = &arg_dtypes[0];
8793
let rhs = &arg_dtypes[1];
8894

8995
// Both must have the same dtype (ignoring top-level nullability).
9096
vortex_ensure!(
9197
lhs.eq_ignore_nullability(rhs),
92-
"cosine_similarity requires both inputs to have the same dtype, got {lhs} and {rhs}"
98+
"CosineSimilarity requires both inputs to have the same dtype, got {lhs} and {rhs}"
9399
);
94100

95101
// We don't need to look at rhs anymore since we know lhs and rhs are equal.
96102

97103
// Both inputs must be tensor-like extension types.
98104
let lhs_ext = lhs.as_extension_opt().ok_or_else(|| {
99-
vortex_err!("cosine_similarity lhs must be an extension type, got {lhs}")
105+
vortex_err!("CosineSimilarity lhs must be an extension type, got {lhs}")
100106
})?;
101107

102108
vortex_ensure!(
103109
AnyTensor::matches(lhs_ext),
104-
"cosine_similarity inputs must be an `AnyTensor`, got {lhs}"
110+
"CosineSimilarity inputs must be an `AnyTensor`, got {lhs}"
105111
);
106112

107113
let ptype = extension_element_ptype(lhs_ext)?;
108114
vortex_ensure!(
109115
ptype.is_float(),
110-
"cosine_similarity element dtype must be a float primitive, got {ptype}"
116+
"CosineSimilarity element dtype must be a float primitive, got {ptype}"
111117
);
112118

113119
let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable());
@@ -190,79 +196,32 @@ fn cosine_similarity_row<T: Float + NativePType>(a: &[T], b: &[T]) -> T {
190196

191197
#[cfg(test)]
192198
mod tests {
193-
use vortex::array::ArrayRef;
194-
use vortex::array::IntoArray;
199+
use rstest::rstest;
195200
use vortex::array::ToCanonical;
196-
use vortex::array::arrays::ConstantArray;
197-
use vortex::array::arrays::ExtensionArray;
198-
use vortex::array::arrays::FixedSizeListArray;
199201
use vortex::array::arrays::ScalarFnArray;
200-
use vortex::array::validity::Validity;
201-
use vortex::buffer::Buffer;
202-
use vortex::dtype::DType;
203-
use vortex::dtype::Nullability;
204-
use vortex::dtype::extension::ExtDType;
205202
use vortex::error::VortexResult;
206-
use vortex::extension::EmptyMetadata;
207-
use vortex::scalar::Scalar;
208203
use vortex::scalar_fn::EmptyOptions;
209204
use vortex::scalar_fn::ScalarFn;
210205

211-
use crate::fixed_shape::FixedShapeTensor;
212-
use crate::fixed_shape::FixedShapeTensorMetadata;
213206
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
214-
use crate::vector::Vector;
215-
216-
/// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape.
217-
///
218-
/// The number of rows is inferred from the total element count divided by the product of the
219-
/// shape dimensions. For 0-dimensional tensors (scalar), each element is one row.
220-
fn tensor_array(shape: &[usize], elements: &[f64]) -> VortexResult<ArrayRef> {
221-
let list_size: u32 = shape.iter().product::<usize>().max(1).try_into().unwrap();
222-
let row_count = elements.len() / list_size as usize;
223-
224-
let elems: ArrayRef = Buffer::copy_from(elements).into_array();
225-
let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count);
226-
227-
let metadata = FixedShapeTensorMetadata::new(shape.to_vec());
228-
let ext_dtype =
229-
ExtDType::<FixedShapeTensor>::try_new(metadata, fsl.dtype().clone())?.erased();
230-
231-
Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
232-
}
207+
use crate::scalar_fns::utils::test_helpers::assert_close;
208+
use crate::scalar_fns::utils::test_helpers::constant_tensor_array;
209+
use crate::scalar_fns::utils::test_helpers::constant_vector_array;
210+
use crate::scalar_fns::utils::test_helpers::tensor_array;
211+
use crate::scalar_fns::utils::test_helpers::vector_array;
233212

234213
/// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec<f64>`.
235-
fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult<Vec<f64>> {
214+
fn eval_cosine_similarity(
215+
lhs: vortex::array::ArrayRef,
216+
rhs: vortex::array::ArrayRef,
217+
len: usize,
218+
) -> VortexResult<Vec<f64>> {
236219
let scalar_fn = ScalarFn::new(CosineSimilarity, EmptyOptions).erased();
237220
let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?;
238221
let prim = result.to_primitive();
239222
Ok(prim.as_slice::<f64>().to_vec())
240223
}
241224

242-
/// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected`
243-
/// value, with support for NaN (NaN == NaN is considered equal).
244-
#[track_caller]
245-
fn assert_close(actual: &[f64], expected: &[f64]) {
246-
assert_eq!(
247-
actual.len(),
248-
expected.len(),
249-
"length mismatch: got {} elements, expected {}",
250-
actual.len(),
251-
expected.len()
252-
);
253-
254-
for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
255-
if a.is_nan() && e.is_nan() {
256-
continue;
257-
}
258-
assert!(
259-
(a - e).abs() < 1e-10,
260-
"element {i}: got {a}, expected {e} (diff = {})",
261-
(a - e).abs()
262-
);
263-
}
264-
}
265-
266225
#[test]
267226
fn unit_vectors_1d() -> VortexResult<()> {
268227
let lhs = tensor_array(
@@ -280,20 +239,18 @@ mod tests {
280239
],
281240
)?;
282241

283-
// Row 0: identical → 1.0, row 1: orthogonal → 0.0.
242+
// Row 0: identical -> 1.0, row 1: orthogonal -> 0.0.
284243
assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]);
285244
Ok(())
286245
}
287246

288-
use rstest::rstest;
289-
290247
/// Single-row cosine similarity for various vector pairs.
291248
#[rstest]
292-
// Antiparallel → -1.0.
249+
// Antiparallel -> -1.0.
293250
#[case::opposite(&[3], &[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0], &[-1.0])]
294-
// dot=24, both magnitudes=5 → 24/25 = 0.96.
251+
// dot=24, both magnitudes=5 -> 24/25 = 0.96.
295252
#[case::non_unit(&[2], &[3.0, 4.0], &[4.0, 3.0], &[0.96])]
296-
// Zero vector → 0/0 → NaN.
253+
// Zero vector -> 0/0 -> NaN.
297254
#[case::zero_norm(&[2], &[0.0, 0.0], &[1.0, 0.0], &[f64::NAN])]
298255
fn single_row(
299256
#[case] shape: &[usize],
@@ -332,14 +289,14 @@ mod tests {
332289
let lhs = tensor_array(&[], &[5.0, 3.0])?;
333290
let rhs = tensor_array(&[], &[5.0, -3.0])?;
334291

335-
// Same sign → 1.0, opposite sign → -1.0.
292+
// Same sign -> 1.0, opposite sign -> -1.0.
336293
assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, -1.0]);
337294
Ok(())
338295
}
339296

340297
#[test]
341298
fn many_rows() -> VortexResult<()> {
342-
// 5 tensors of shape [4] compared against themselves → all 1.0.
299+
// 5 tensors of shape [4] compared against themselves -> all 1.0.
343300
let lhs = tensor_array(
344301
&[4],
345302
&[
@@ -359,35 +316,8 @@ mod tests {
359316
Ok(())
360317
}
361318

362-
/// Builds an extension array whose storage is a [`ConstantArray`], representing a single
363-
/// query tensor broadcast to `len` rows.
364-
fn constant_tensor_array(
365-
shape: &[usize],
366-
elements: &[f64],
367-
len: usize,
368-
) -> VortexResult<ArrayRef> {
369-
let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable);
370-
371-
// Build the FSL storage scalar from individual element scalars.
372-
let children: Vec<Scalar> = elements
373-
.iter()
374-
.map(|&v| Scalar::primitive(v, Nullability::NonNullable))
375-
.collect();
376-
let storage_scalar =
377-
Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
378-
379-
// Wrap the FSL scalar in a ConstantArray to avoid materializing `len` copies.
380-
let storage = ConstantArray::new(storage_scalar, len).into_array();
381-
382-
let metadata = FixedShapeTensorMetadata::new(shape.to_vec());
383-
let ext_dtype =
384-
ExtDType::<FixedShapeTensor>::try_new(metadata, storage.dtype().clone())?.erased();
385-
386-
Ok(ExtensionArray::new(ext_dtype, storage).into_array())
387-
}
388-
389319
#[test]
390-
fn constant_query_vector() -> VortexResult<()> {
320+
fn constant_query_tensor() -> VortexResult<()> {
391321
// Compare 4 tensors of shape [3] against a single constant query tensor [1,0,0].
392322
let data = tensor_array(
393323
&[3],
@@ -400,26 +330,13 @@ mod tests {
400330
)?;
401331
let query = constant_tensor_array(&[3], &[1.0, 0.0, 0.0], 4)?;
402332

403-
// Only tensor 0 is aligned with the query.
404333
assert_close(
405334
&eval_cosine_similarity(data, query, 4)?,
406335
&[1.0, 0.0, 0.0, 1.0],
407336
);
408337
Ok(())
409338
}
410339

411-
/// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size.
412-
fn vector_array(dim: u32, elements: &[f64]) -> VortexResult<ArrayRef> {
413-
let row_count = elements.len() / dim as usize;
414-
415-
let elems: ArrayRef = Buffer::copy_from(elements).into_array();
416-
let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count);
417-
418-
let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
419-
420-
Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
421-
}
422-
423340
#[test]
424341
fn vector_unit_vectors() -> VortexResult<()> {
425342
let lhs = vector_array(
@@ -442,43 +359,6 @@ mod tests {
442359
Ok(())
443360
}
444361

445-
#[test]
446-
fn vector_self_similarity() -> VortexResult<()> {
447-
let arr = vector_array(
448-
4,
449-
&[
450-
1.0, 2.0, 3.0, 4.0, // vector 0
451-
0.0, 1.0, 0.0, 0.0, // vector 1
452-
5.0, 0.0, 5.0, 0.0, // vector 2
453-
],
454-
)?;
455-
456-
assert_close(
457-
&eval_cosine_similarity(arr.clone(), arr, 3)?,
458-
&[1.0, 1.0, 1.0],
459-
);
460-
Ok(())
461-
}
462-
463-
/// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`].
464-
fn constant_vector_array(elements: &[f64], len: usize) -> VortexResult<ArrayRef> {
465-
let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable);
466-
467-
let children: Vec<Scalar> = elements
468-
.iter()
469-
.map(|&v| Scalar::primitive(v, Nullability::NonNullable))
470-
.collect();
471-
let storage_scalar =
472-
Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
473-
474-
let storage = ConstantArray::new(storage_scalar, len).into_array();
475-
476-
let ext_dtype =
477-
ExtDType::<Vector>::try_new(EmptyMetadata, storage.dtype().clone())?.erased();
478-
479-
Ok(ExtensionArray::new(ext_dtype, storage).into_array())
480-
}
481-
482362
#[test]
483363
fn vector_constant_query() -> VortexResult<()> {
484364
let data = vector_array(

0 commit comments

Comments
 (0)