Skip to content

Commit 5a198e8

Browse files
committed
clean up
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 1198952 commit 5a198e8

5 files changed

Lines changed: 173 additions & 241 deletions

File tree

vortex-array/src/scalar_fn/vtable.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync {
9292
Ok(args.to_vec())
9393
}
9494

95+
// TODO(connor): This needs a precondition for the number of args it has, or all implementations
96+
// need to return an error if it is wrong.
9597
/// Compute the return [`DType`] of the expression if evaluated over the given input types.
9698
fn return_dtype(&self, options: &Self::Options, args: &[DType]) -> VortexResult<DType>;
9799

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 29 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use vortex::dtype::Nullability;
1919
use vortex::dtype::extension::Matcher;
2020
use vortex::error::VortexResult;
2121
use vortex::error::vortex_ensure;
22+
use vortex::error::vortex_ensure_eq;
2223
use vortex::error::vortex_err;
2324
use vortex::expr::Expression;
2425
use vortex::scalar_fn::Arity;
@@ -81,33 +82,38 @@ impl ScalarFnVTable for CosineSimilarity {
8182
}
8283

8384
fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
84-
debug_assert_eq!(arg_dtypes.len(), 2);
85+
vortex_ensure_eq!(
86+
arg_dtypes.len(),
87+
2,
88+
"CosineSimilarity requires exactly 2 arguments, got {}",
89+
arg_dtypes.len()
90+
);
8591

8692
let lhs = &arg_dtypes[0];
8793
let rhs = &arg_dtypes[1];
8894

8995
// Both must have the same dtype (ignoring top-level nullability).
9096
vortex_ensure!(
9197
lhs.eq_ignore_nullability(rhs),
92-
"cosine_similarity requires both inputs to have the same dtype, got {lhs} and {rhs}"
98+
"CosineSimilarity requires both inputs to have the same dtype, got {lhs} and {rhs}"
9399
);
94100

95101
// We don't need to look at rhs anymore since we know lhs and rhs are equal.
96102

97103
// Both inputs must be tensor-like extension types.
98104
let lhs_ext = lhs.as_extension_opt().ok_or_else(|| {
99-
vortex_err!("cosine_similarity lhs must be an extension type, got {lhs}")
105+
vortex_err!("CosineSimilarity lhs must be an extension type, got {lhs}")
100106
})?;
101107

102108
vortex_ensure!(
103109
AnyTensor::matches(lhs_ext),
104-
"cosine_similarity inputs must be an `AnyTensor`, got {lhs}"
110+
"CosineSimilarity inputs must be an `AnyTensor`, got {lhs}"
105111
);
106112

107113
let ptype = extension_element_ptype(lhs_ext)?;
108114
vortex_ensure!(
109115
ptype.is_float(),
110-
"cosine_similarity element dtype must be a float primitive, got {ptype}"
116+
"CosineSimilarity element dtype must be a float primitive, got {ptype}"
111117
);
112118

113119
let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable());
@@ -191,79 +197,32 @@ fn cosine_similarity_row<T: Float + NativePType>(a: &[T], b: &[T]) -> T {
191197

192198
#[cfg(test)]
193199
mod tests {
194-
use vortex::array::ArrayRef;
195-
use vortex::array::IntoArray;
200+
use rstest::rstest;
196201
use vortex::array::ToCanonical;
197-
use vortex::array::arrays::ConstantArray;
198-
use vortex::array::arrays::ExtensionArray;
199-
use vortex::array::arrays::FixedSizeListArray;
200202
use vortex::array::arrays::ScalarFnArray;
201-
use vortex::array::validity::Validity;
202-
use vortex::buffer::Buffer;
203-
use vortex::dtype::DType;
204-
use vortex::dtype::Nullability;
205-
use vortex::dtype::extension::ExtDType;
206203
use vortex::error::VortexResult;
207-
use vortex::extension::EmptyMetadata;
208-
use vortex::scalar::Scalar;
209204
use vortex::scalar_fn::EmptyOptions;
210205
use vortex::scalar_fn::ScalarFn;
211206

212-
use crate::fixed_shape::FixedShapeTensor;
213-
use crate::fixed_shape::FixedShapeTensorMetadata;
214207
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
215-
use crate::vector::Vector;
216-
217-
/// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape.
218-
///
219-
/// The number of rows is inferred from the total element count divided by the product of the
220-
/// shape dimensions. For 0-dimensional tensors (scalar), each element is one row.
221-
fn tensor_array(shape: &[usize], elements: &[f64]) -> VortexResult<ArrayRef> {
222-
let list_size: u32 = shape.iter().product::<usize>().max(1).try_into().unwrap();
223-
let row_count = elements.len() / list_size as usize;
224-
225-
let elems: ArrayRef = Buffer::copy_from(elements).into_array();
226-
let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count);
227-
228-
let metadata = FixedShapeTensorMetadata::new(shape.to_vec());
229-
let ext_dtype =
230-
ExtDType::<FixedShapeTensor>::try_new(metadata, fsl.dtype().clone())?.erased();
231-
232-
Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
233-
}
208+
use crate::scalar_fns::utils::test_helpers::assert_close;
209+
use crate::scalar_fns::utils::test_helpers::constant_tensor_array;
210+
use crate::scalar_fns::utils::test_helpers::constant_vector_array;
211+
use crate::scalar_fns::utils::test_helpers::tensor_array;
212+
use crate::scalar_fns::utils::test_helpers::vector_array;
234213

235214
/// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec<f64>`.
236-
fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult<Vec<f64>> {
215+
fn eval_cosine_similarity(
216+
lhs: vortex::array::ArrayRef,
217+
rhs: vortex::array::ArrayRef,
218+
len: usize,
219+
) -> VortexResult<Vec<f64>> {
237220
let scalar_fn = ScalarFn::new(CosineSimilarity, EmptyOptions).erased();
238221
let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?;
239222
let prim = result.to_primitive();
240223
Ok(prim.as_slice::<f64>().to_vec())
241224
}
242225

243-
/// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected`
244-
/// value, with support for NaN (NaN == NaN is considered equal).
245-
#[track_caller]
246-
fn assert_close(actual: &[f64], expected: &[f64]) {
247-
assert_eq!(
248-
actual.len(),
249-
expected.len(),
250-
"length mismatch: got {} elements, expected {}",
251-
actual.len(),
252-
expected.len()
253-
);
254-
255-
for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
256-
if a.is_nan() && e.is_nan() {
257-
continue;
258-
}
259-
assert!(
260-
(a - e).abs() < 1e-10,
261-
"element {i}: got {a}, expected {e} (diff = {})",
262-
(a - e).abs()
263-
);
264-
}
265-
}
266-
267226
#[test]
268227
fn unit_vectors_1d() -> VortexResult<()> {
269228
let lhs = tensor_array(
@@ -281,20 +240,18 @@ mod tests {
281240
],
282241
)?;
283242

284-
// Row 0: identical → 1.0, row 1: orthogonal → 0.0.
243+
// Row 0: identical -> 1.0, row 1: orthogonal -> 0.0.
285244
assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]);
286245
Ok(())
287246
}
288247

289-
use rstest::rstest;
290-
291248
/// Single-row cosine similarity for various vector pairs.
292249
#[rstest]
293-
// Antiparallel → -1.0.
250+
// Antiparallel -> -1.0.
294251
#[case::opposite(&[3], &[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0], &[-1.0])]
295-
// dot=24, both magnitudes=5 → 24/25 = 0.96.
252+
// dot=24, both magnitudes=5 -> 24/25 = 0.96.
296253
#[case::non_unit(&[2], &[3.0, 4.0], &[4.0, 3.0], &[0.96])]
297-
// Zero vector → 0/0 → NaN.
254+
// Zero vector -> 0/0 -> NaN.
298255
#[case::zero_norm(&[2], &[0.0, 0.0], &[1.0, 0.0], &[f64::NAN])]
299256
fn single_row(
300257
#[case] shape: &[usize],
@@ -333,14 +290,14 @@ mod tests {
333290
let lhs = tensor_array(&[], &[5.0, 3.0])?;
334291
let rhs = tensor_array(&[], &[5.0, -3.0])?;
335292

336-
// Same sign → 1.0, opposite sign → -1.0.
293+
// Same sign -> 1.0, opposite sign -> -1.0.
337294
assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, -1.0]);
338295
Ok(())
339296
}
340297

341298
#[test]
342299
fn many_rows() -> VortexResult<()> {
343-
// 5 tensors of shape [4] compared against themselves → all 1.0.
300+
// 5 tensors of shape [4] compared against themselves -> all 1.0.
344301
let lhs = tensor_array(
345302
&[4],
346303
&[
@@ -360,35 +317,8 @@ mod tests {
360317
Ok(())
361318
}
362319

363-
/// Builds an extension array whose storage is a [`ConstantArray`], representing a single
364-
/// query tensor broadcast to `len` rows.
365-
fn constant_tensor_array(
366-
shape: &[usize],
367-
elements: &[f64],
368-
len: usize,
369-
) -> VortexResult<ArrayRef> {
370-
let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable);
371-
372-
// Build the FSL storage scalar from individual element scalars.
373-
let children: Vec<Scalar> = elements
374-
.iter()
375-
.map(|&v| Scalar::primitive(v, Nullability::NonNullable))
376-
.collect();
377-
let storage_scalar =
378-
Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
379-
380-
// Wrap the FSL scalar in a ConstantArray to avoid materializing `len` copies.
381-
let storage = ConstantArray::new(storage_scalar, len).into_array();
382-
383-
let metadata = FixedShapeTensorMetadata::new(shape.to_vec());
384-
let ext_dtype =
385-
ExtDType::<FixedShapeTensor>::try_new(metadata, storage.dtype().clone())?.erased();
386-
387-
Ok(ExtensionArray::new(ext_dtype, storage).into_array())
388-
}
389-
390320
#[test]
391-
fn constant_query_vector() -> VortexResult<()> {
321+
fn constant_query_tensor() -> VortexResult<()> {
392322
// Compare 4 tensors of shape [3] against a single constant query tensor [1,0,0].
393323
let data = tensor_array(
394324
&[3],
@@ -401,26 +331,13 @@ mod tests {
401331
)?;
402332
let query = constant_tensor_array(&[3], &[1.0, 0.0, 0.0], 4)?;
403333

404-
// Only tensor 0 is aligned with the query.
405334
assert_close(
406335
&eval_cosine_similarity(data, query, 4)?,
407336
&[1.0, 0.0, 0.0, 1.0],
408337
);
409338
Ok(())
410339
}
411340

412-
/// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size.
413-
fn vector_array(dim: u32, elements: &[f64]) -> VortexResult<ArrayRef> {
414-
let row_count = elements.len() / dim as usize;
415-
416-
let elems: ArrayRef = Buffer::copy_from(elements).into_array();
417-
let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count);
418-
419-
let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
420-
421-
Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
422-
}
423-
424341
#[test]
425342
fn vector_unit_vectors() -> VortexResult<()> {
426343
let lhs = vector_array(
@@ -443,43 +360,6 @@ mod tests {
443360
Ok(())
444361
}
445362

446-
#[test]
447-
fn vector_self_similarity() -> VortexResult<()> {
448-
let arr = vector_array(
449-
4,
450-
&[
451-
1.0, 2.0, 3.0, 4.0, // vector 0
452-
0.0, 1.0, 0.0, 0.0, // vector 1
453-
5.0, 0.0, 5.0, 0.0, // vector 2
454-
],
455-
)?;
456-
457-
assert_close(
458-
&eval_cosine_similarity(arr.clone(), arr, 3)?,
459-
&[1.0, 1.0, 1.0],
460-
);
461-
Ok(())
462-
}
463-
464-
/// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`].
465-
fn constant_vector_array(elements: &[f64], len: usize) -> VortexResult<ArrayRef> {
466-
let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable);
467-
468-
let children: Vec<Scalar> = elements
469-
.iter()
470-
.map(|&v| Scalar::primitive(v, Nullability::NonNullable))
471-
.collect();
472-
let storage_scalar =
473-
Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
474-
475-
let storage = ConstantArray::new(storage_scalar, len).into_array();
476-
477-
let ext_dtype =
478-
ExtDType::<Vector>::try_new(EmptyMetadata, storage.dtype().clone())?.erased();
479-
480-
Ok(ExtensionArray::new(ext_dtype, storage).into_array())
481-
}
482-
483363
#[test]
484364
fn vector_constant_query() -> VortexResult<()> {
485365
let data = vector_array(

0 commit comments

Comments
 (0)