Skip to content

Commit 39e0a0e

Browse files
committed
add most implementation
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 17e0e1d commit 39e0a0e

8 files changed

Lines changed: 145 additions & 45 deletions

File tree

vortex-tensor/src/encodings/norm/array.rs

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex::array::ArrayRef;
5+
use vortex::array::ExecutionCtx;
6+
use vortex::dtype::DType;
7+
use vortex::dtype::Nullability;
8+
use vortex::error::VortexResult;
9+
use vortex::error::vortex_ensure;
10+
use vortex::error::vortex_ensure_eq;
11+
use vortex::error::vortex_err;
12+
13+
use crate::utils::extension_element_ptype;
14+
use crate::vector::Vector;
515

616
/// A normalized array that stores unit-normalized vectors alongside their original L2 norms.
717
///
@@ -12,15 +22,54 @@ pub struct NormVectorArray {
1222
/// The backing vector array that has been unit normalized.
1323
///
1424
/// The underlying elements of the vector array must be floating-point.
15-
vector_array: ArrayRef,
25+
pub(crate) vector_array: ArrayRef,
1626

1727
/// The L2 (Frobenius) norms of each vector.
1828
///
1929
/// This must have the same dtype as the elements of the vector array.
20-
norms: ArrayRef,
30+
pub(crate) norms: ArrayRef,
2131
}
2232

2333
impl NormVectorArray {
34+
/// Creates a new [`NormVectorArray`] from a unit-normalized vector array and its L2 norms.
35+
///
36+
/// The `vector_array` must be a [`Vector`] extension array with floating-point elements, and
37+
/// `norms` must be a primitive array of the same float type with the same length.
///
/// This constructor only checks dtypes and lengths. It does not inspect the values, so it does
/// not verify that the vectors are actually unit length or that `norms` holds their true norms.
///
/// # Errors
///
/// Returns an error if the dtype of `vector_array` is not an extension type, if that extension
/// is not [`Vector`], if `extension_element_ptype` fails, if the dtype of `norms` is not the
/// non-nullable primitive dtype of the vector element type, or if the two arrays differ in
/// length.
38+
pub fn try_new(vector_array: ArrayRef, norms: ArrayRef) -> VortexResult<Self> {
39+
// The vector child must carry an extension dtype at all before we can ask which one it is.
let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| {
40+
vortex_err!(
41+
"vector_array dtype must be an extension type, got {}",
42+
vector_array.dtype()
43+
)
44+
})?;
45+
46+
vortex_ensure!(
47+
ext.is::<Vector>(),
48+
"vector_array must have the Vector extension type, got {}",
49+
vector_array.dtype()
50+
);
51+
52+
// NOTE(review): no explicit floating-point check is visible here; presumably
// `extension_element_ptype` or the `Vector` type enforces it — confirm.
let element_ptype = extension_element_ptype(ext)?;
53+
54+
// Norms are required to be non-nullable, so a nullable norms array is rejected here.
let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable);
55+
vortex_ensure_eq!(
56+
*norms.dtype(),
57+
expected_norms_dtype,
58+
"norms dtype must match vector element type"
59+
);
60+
61+
vortex_ensure_eq!(
62+
vector_array.len(),
63+
norms.len(),
64+
"vector_array and norms must have the same length"
65+
);
66+
67+
Ok(Self {
68+
vector_array,
69+
norms,
70+
})
71+
}
72+
2473
/// Returns a reference to the backing vector array that has been unit normalized.
2574
pub fn vector_array(&self) -> &ArrayRef {
2675
&self.vector_array
@@ -30,4 +79,9 @@ impl NormVectorArray {
3079
pub fn norms(&self) -> &ArrayRef {
3180
&self.norms
3281
}
82+
83+
// TODO docs
// NOTE(review): this method is not implemented yet. The body is `todo!()`, which panics when
// called, and the vtable `execute` entry point forwards straight to it, so executing a
// `NormVectorArray` currently panics. Presumably the intended behavior is to rebuild the
// original vectors from `vector_array` and `norms` — confirm before implementing.
84+
pub(super) fn execute_into_vector(&self, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
85+
todo!()
86+
}
3387
}

vortex-tensor/src/encodings/norm/vtable/mod.rs

Lines changed: 70 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
use std::hash::Hasher;
55

6+
use vortex::array::ArrayEq;
7+
use vortex::array::ArrayHash;
68
use vortex::array::ArrayRef;
79
use vortex::array::EmptyMetadata;
810
use vortex::array::ExecutionCtx;
@@ -16,10 +18,15 @@ use vortex::array::vtable::ArrayId;
1618
use vortex::array::vtable::VTable;
1719
use vortex::array::vtable::ValidityVTableFromChild;
1820
use vortex::dtype::DType;
21+
use vortex::dtype::Nullability;
1922
use vortex::error::VortexResult;
23+
use vortex::error::vortex_ensure_eq;
24+
use vortex::error::vortex_err;
25+
use vortex::error::vortex_panic;
2026
use vortex::session::VortexSession;
2127

2228
use crate::encodings::norm::array::NormVectorArray;
29+
use crate::utils::extension_element_ptype;
2330

2431
mod operations;
2532
mod validity;
@@ -52,70 +59,109 @@ impl VTable for NormVector {
5259
}
5360

5461
/// Feeds both children into `state`: the vector array first, then the norms, each hashed with
/// the caller-supplied `precision`.
fn array_hash<H: Hasher>(array: &NormVectorArray, state: &mut H, precision: Precision) {
55-
todo!()
62+
array.vector_array().array_hash(state, precision);
63+
array.norms().array_hash(state, precision);
5664
}
5765

5866
/// Returns `true` when both the norms and the vector arrays of `array` and `other` compare
/// equal at the given `precision`.
///
/// The norms are compared first; `&&` short-circuits, so the vector arrays are only compared
/// when the norms already match.
fn array_eq(array: &NormVectorArray, other: &NormVectorArray, precision: Precision) -> bool {
59-
todo!()
67+
array.norms().array_eq(other.norms(), precision)
68+
&& array
69+
.vector_array()
70+
.array_eq(other.vector_array(), precision)
6071
}
6172

62-
/// Reports the number of buffers, which is always `0`: the array holds its data only in the
/// two child arrays returned by `child`.
fn nbuffers(array: &NormVectorArray) -> usize {
63-
todo!()
73+
fn nbuffers(_array: &NormVectorArray) -> usize {
74+
0
6475
}
6576

66-
/// Always panics via `vortex_panic!`, for every `idx`, because `nbuffers` reports zero buffers.
fn buffer(array: &NormVectorArray, idx: usize) -> BufferHandle {
67-
todo!()
77+
fn buffer(_array: &NormVectorArray, idx: usize) -> BufferHandle {
78+
vortex_panic!("NormVectorArray has no buffers (index {idx})")
6879
}
6980

70-
/// Always panics via `vortex_panic!`, for every `idx`, because `nbuffers` reports zero buffers.
// NOTE(review): the return type is `Option<String>`, yet `None` is never returned; callers
// that probe for a buffer name will panic instead. Confirm whether the trait contract expects
// `None` for an out-of-range index.
fn buffer_name(array: &NormVectorArray, idx: usize) -> Option<String> {
71-
todo!()
81+
fn buffer_name(_array: &NormVectorArray, idx: usize) -> Option<String> {
82+
vortex_panic!("NormVectorArray has no buffers (index {idx})")
7283
}
7384

74-
/// Reports the number of children, which is always `2`: index 0 is the vector array and
/// index 1 is the norms (see `child`).
fn nchildren(array: &NormVectorArray) -> usize {
75-
todo!()
85+
fn nchildren(_array: &NormVectorArray) -> usize {
86+
2
7687
}
7788

7889
/// Returns a clone of the child at `idx`: `0` is the vector array, `1` is the norms.
///
/// # Panics
///
/// Panics via `vortex_panic!` for any other index.
fn child(array: &NormVectorArray, idx: usize) -> ArrayRef {
79-
todo!()
90+
match idx {
91+
0 => array.vector_array().clone(),
92+
1 => array.norms().clone(),
93+
_ => vortex_panic!("NormVectorArray child index {idx} out of bounds"),
94+
}
8095
}
8196

82-
/// Returns the name of the child at `idx`: `"vector_array"` for `0` and `"norms"` for `1`,
/// matching the order used by `child`.
///
/// # Panics
///
/// Panics via `vortex_panic!` for any other index.
fn child_name(array: &NormVectorArray, idx: usize) -> String {
83-
todo!()
97+
fn child_name(_array: &NormVectorArray, idx: usize) -> String {
98+
match idx {
99+
0 => "vector_array".to_string(),
100+
1 => "norms".to_string(),
101+
_ => vortex_panic!("NormVectorArray child_name index {idx} out of bounds"),
102+
}
84103
}
85104

86-
/// Returns `EmptyMetadata` unconditionally; the array argument is not consulted.
fn metadata(array: &NormVectorArray) -> VortexResult<Self::Metadata> {
87-
todo!()
105+
fn metadata(_array: &NormVectorArray) -> VortexResult<Self::Metadata> {
106+
Ok(EmptyMetadata)
88107
}
89108

90-
/// Serializes the (empty) metadata as `Some` of an empty byte vector.
// NOTE(review): this returns `Some(vec![])` rather than `None`; presumably `None` has a
// distinct meaning in the vtable contract — confirm against the `VTable` trait docs.
fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
91-
todo!()
109+
fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
110+
Ok(Some(vec![]))
92111
}
93112

94113
/// Returns `EmptyMetadata` unconditionally.
///
/// All five parameters are ignored; in particular the input bytes are not checked for being
/// empty, so any byte payload is accepted.
fn deserialize(
95-
bytes: &[u8],
114+
_bytes: &[u8],
96115
_dtype: &DType,
97116
_len: usize,
98117
_buffers: &[BufferHandle],
99118
_session: &VortexSession,
100119
) -> VortexResult<Self::Metadata> {
101-
todo!()
120+
Ok(EmptyMetadata)
102121
}
103122

104123
/// Rebuilds a `NormVectorArray` from its two children.
///
/// Child 0 (the vector array) is requested with the array's own `dtype` and `len`. Child 1
/// (the norms) is requested with a non-nullable primitive dtype derived from `dtype` via
/// `extension_element_ptype`, and the same `len`. Metadata and buffers are unused.
///
/// # Errors
///
/// Returns an error if `children` does not contain exactly two entries, if `dtype` is not an
/// extension type, if `extension_element_ptype` or either `children.get` call fails, or if
/// `NormVectorArray::try_new` rejects the resulting pair.
fn build(
105124
dtype: &DType,
106125
len: usize,
107-
metadata: &Self::Metadata,
108-
buffers: &[BufferHandle],
126+
_metadata: &Self::Metadata,
127+
_buffers: &[BufferHandle],
109128
children: &dyn ArrayChildren,
110129
) -> VortexResult<NormVectorArray> {
111-
todo!()
130+
vortex_ensure_eq!(
131+
children.len(),
132+
2,
133+
"NormVectorArray requires exactly 2 children"
134+
);
135+
136+
// The vector child shares the dtype and length of the encoded array itself.
let vector_array = children.get(0, dtype, len)?;
137+
138+
let ext = dtype.as_extension_opt().ok_or_else(|| {
139+
vortex_err!("NormVectorArray dtype must be an extension type, got {dtype}")
140+
})?;
141+
let element_ptype = extension_element_ptype(ext)?;
142+
let norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable);
143+
let norms = children.get(1, &norms_dtype, len)?;
144+
145+
// `try_new` repeats the dtype and length checks on the fetched children.
NormVectorArray::try_new(vector_array, norms)
112146
}
113147

114148
/// Replaces both children of `array` in place: `children[0]` becomes the vector array and
/// `children[1]` becomes the norms.
///
/// # Errors
///
/// Returns an error if `children` does not contain exactly two entries.
// NOTE(review): the new children are assigned directly to the fields and do not go through
// `NormVectorArray::try_new`, so the dtype and length checks performed there are not repeated
// here — confirm that callers guarantee compatible children.
fn with_children(array: &mut NormVectorArray, children: Vec<ArrayRef>) -> VortexResult<()> {
115-
todo!()
149+
vortex_ensure_eq!(
150+
children.len(),
151+
2,
152+
"NormVectorArray requires exactly 2 children"
153+
);
154+
155+
// Converting the `Vec` into a two-element array moves the children out without cloning.
// Presumably the `map_err` arm cannot be reached after the length check above — confirm.
let [vector_array, norms]: [ArrayRef; 2] = children
156+
.try_into()
157+
.map_err(|_| vortex_err!("NormVectorArray requires exactly 2 children"))?;
158+
159+
array.vector_array = vector_array;
160+
array.norms = norms;
161+
Ok(())
116162
}
117163

118164
/// Executes the array in a single step by delegating to
/// `NormVectorArray::execute_into_vector` and wrapping the result in `ExecutionStep::Done`.
// NOTE(review): `execute_into_vector` is still `todo!()` in this commit, so this currently
// panics when called.
fn execute(array: &NormVectorArray, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
119-
todo!()
165+
Ok(ExecutionStep::Done(array.execute_into_vector(ctx)?))
120166
}
121167
}

vortex-tensor/src/encodings/norm/vtable/operations.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ use crate::encodings::norm::vtable::NormVector;
1010

1111
/// Element-access operations for the `NormVector` encoding.
impl OperationsVTable<NormVector> for NormVector {
1212
/// Returns the scalar at `index` by delegating to the backing vector array.
// NOTE(review): the norms child is never consulted, so the returned scalar is the
// unit-normalized vector rather than the vector rescaled by its stored norm. Confirm whether
// the logical value of this array is the normalized or the original vector.
fn scalar_at(array: &NormVectorArray, index: usize) -> VortexResult<Scalar> {
13-
todo!()
13+
array.vector_array().scalar_at(index)
1414
}
1515
}

vortex-tensor/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@ pub mod fixed_shape;
1212
pub mod vector;
1313

1414
pub mod encodings;
15+
16+
mod utils;

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId;
2828
use vortex::scalar_fn::ScalarFnVTable;
2929

3030
use crate::matcher::AnyTensor;
31-
use crate::scalar_fns::utils::extension_element_ptype;
32-
use crate::scalar_fns::utils::extension_list_size;
33-
use crate::scalar_fns::utils::extension_storage;
34-
use crate::scalar_fns::utils::extract_flat_elements;
31+
use crate::utils::extension_element_ptype;
32+
use crate::utils::extension_list_size;
33+
use crate::utils::extension_storage;
34+
use crate::utils::extract_flat_elements;
3535

3636
/// Cosine similarity between two columns.
3737
///
@@ -196,11 +196,11 @@ mod tests {
196196
use vortex::scalar_fn::ScalarFn;
197197

198198
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
199-
use crate::scalar_fns::utils::test_helpers::assert_close;
200-
use crate::scalar_fns::utils::test_helpers::constant_tensor_array;
201-
use crate::scalar_fns::utils::test_helpers::constant_vector_array;
202-
use crate::scalar_fns::utils::test_helpers::tensor_array;
203-
use crate::scalar_fns::utils::test_helpers::vector_array;
199+
use crate::utils::test_helpers::assert_close;
200+
use crate::utils::test_helpers::constant_tensor_array;
201+
use crate::utils::test_helpers::constant_vector_array;
202+
use crate::utils::test_helpers::tensor_array;
203+
use crate::utils::test_helpers::vector_array;
204204

205205
/// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec<f64>`.
206206
fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult<Vec<f64>> {

vortex-tensor/src/scalar_fns/l2_norm.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId;
2828
use vortex::scalar_fn::ScalarFnVTable;
2929

3030
use crate::matcher::AnyTensor;
31-
use crate::scalar_fns::utils::extension_element_ptype;
32-
use crate::scalar_fns::utils::extension_list_size;
33-
use crate::scalar_fns::utils::extension_storage;
34-
use crate::scalar_fns::utils::extract_flat_elements;
31+
use crate::utils::extension_element_ptype;
32+
use crate::utils::extension_list_size;
33+
use crate::utils::extension_storage;
34+
use crate::utils::extract_flat_elements;
3535

3636
/// L2 norm (Euclidean norm) of a tensor or vector column.
3737
///
@@ -163,9 +163,9 @@ mod tests {
163163
use vortex::scalar_fn::ScalarFn;
164164

165165
use crate::scalar_fns::l2_norm::L2Norm;
166-
use crate::scalar_fns::utils::test_helpers::assert_close;
167-
use crate::scalar_fns::utils::test_helpers::tensor_array;
168-
use crate::scalar_fns::utils::test_helpers::vector_array;
166+
use crate::utils::test_helpers::assert_close;
167+
use crate::utils::test_helpers::tensor_array;
168+
use crate::utils::test_helpers::vector_array;
169169

170170
/// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec<f64>`.
171171
fn eval_l2_norm(input: vortex::array::ArrayRef, len: usize) -> VortexResult<Vec<f64>> {

vortex-tensor/src/scalar_fns/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,3 @@
55
66
pub mod cosine_similarity;
77
pub mod l2_norm;
8-
9-
mod utils;

0 commit comments

Comments
 (0)