Skip to content

Commit e288bb1

Browse files
committed
use broadcast with runend sequence and binary
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 5678362 commit e288bb1

5 files changed

Lines changed: 115 additions & 53 deletions

File tree

vortex-tensor/src/encodings/norm/array.rs

Lines changed: 109 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,33 @@ use num_traits::Zero;
66
use vortex::array::ArrayRef;
77
use vortex::array::ExecutionCtx;
88
use vortex::array::IntoArray;
9+
use vortex::array::LEGACY_SESSION;
10+
use vortex::array::VortexSessionExecute;
911
use vortex::array::arrays::ExtensionArray;
1012
use vortex::array::arrays::FixedSizeListArray;
1113
use vortex::array::arrays::PrimitiveArray;
14+
use vortex::array::builtins::ArrayBuiltins;
1215
use vortex::array::match_each_float_ptype;
1316
use vortex::array::stats::ArrayStats;
1417
use vortex::array::validity::Validity;
1518
use vortex::dtype::DType;
1619
use vortex::dtype::Nullability;
1720
use vortex::dtype::extension::ExtDType;
1821
use vortex::dtype::extension::ExtDTypeRef;
22+
use vortex::encodings::runend::RunEndArray;
23+
use vortex::encodings::sequence::SequenceArray;
24+
use vortex::error::VortexExpect;
1925
use vortex::error::VortexResult;
2026
use vortex::error::vortex_ensure;
2127
use vortex::error::vortex_ensure_eq;
2228
use vortex::error::vortex_err;
2329
use vortex::expr::Expression;
2430
use vortex::expr::root;
2531
use vortex::extension::EmptyMetadata;
32+
use vortex::scalar::PValue;
2633
use vortex::scalar_fn::EmptyOptions;
2734
use vortex::scalar_fn::ScalarFn;
35+
use vortex::scalar_fn::fns::operators::Operator;
2836

2937
use crate::scalar_fns::l2_norm::L2Norm;
3038
use crate::utils::extension_element_ptype;
@@ -45,12 +53,13 @@ pub struct NormVectorArray {
4553
/// The backing vector array that has been unit normalized.
4654
///
4755
/// The underlying elements of the vector array must be floating-point. This child may be
48-
/// nullable; its validity determines the validity of the `NormVectorArray`.
56+
/// nullable, and its validity determines the validity of the `NormVectorArray`.
4957
pub(crate) vector_array: ArrayRef,
5058

5159
/// The L2 norms of each vector.
5260
///
53-
/// This must have the same dtype as the elements of the vector array.
61+
/// This must have the same validity as the vector array, and the same dtype as the elements of
62+
/// the vector array.
5463
pub(crate) norms: ArrayRef,
5564

5665
/// Stats set owned by this array.
@@ -65,7 +74,7 @@ impl NormVectorArray {
6574
/// `norms` must be a primitive array of the same float type with the same length. The
6675
/// `vector_array` may be nullable.
6776
pub fn try_new(vector_array: ArrayRef, norms: ArrayRef) -> VortexResult<Self> {
68-
let ext = Self::validate(&vector_array)?;
77+
let ext = Self::validate(&vector_array, &norms)?;
6978

7079
let element_ptype = extension_element_ptype(&ext)?;
7180

@@ -90,9 +99,9 @@ impl NormVectorArray {
9099
})
91100
}
92101

93-
/// Validates that the given array has the [`Vector`] extension type and returns the extension
94-
/// dtype.
95-
fn validate(vector_array: &ArrayRef) -> VortexResult<ExtDTypeRef> {
102+
/// Validates that the given array has the [`Vector`] extension type and returns the
103+
/// [`ExtDTypeRef`] of the vector array on success.
104+
fn validate_vector_array(vector_array: &ArrayRef) -> VortexResult<ExtDTypeRef> {
96105
let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| {
97106
vortex_err!(
98107
"vector_array dtype must be an extension type, got {}",
@@ -109,6 +118,54 @@ impl NormVectorArray {
109118
Ok(ext.clone())
110119
}
111120

121+
/// Validates that the given `vector_array` and `norms` are compatible.
122+
///
123+
/// Checks that:
124+
/// - The `vector_array` has the [`Vector`] extension type.
125+
/// - Both arrays have the same length.
126+
/// - The element primitive type of the vectors matches the primitive type of the norms.
127+
/// - Both arrays share the same validity mask.
128+
///
129+
/// Returns the [`ExtDTypeRef`] of the vector array on success.
130+
fn validate(vector_array: &ArrayRef, norms: &ArrayRef) -> VortexResult<ExtDTypeRef> {
131+
let ext = Self::validate_vector_array(vector_array)?;
132+
133+
vortex_ensure_eq!(
134+
vector_array.len(),
135+
norms.len(),
136+
"vector_array and norms must have the same length"
137+
);
138+
139+
let element_ptype = extension_element_ptype(&ext)?;
140+
vortex_ensure_eq!(
141+
element_ptype,
142+
norms.dtype().as_ptype(),
143+
"vector elements ptype must be the same as the norms ptype"
144+
);
145+
146+
// TODO(connor): Is there a better way to do this?
147+
let mut ctx = LEGACY_SESSION.create_execution_ctx();
148+
let mask_eq = vector_array
149+
.validity()?
150+
.mask_eq(&norms.validity()?, &mut ctx)?;
151+
vortex_ensure!(
152+
mask_eq,
153+
"vector_array and norms must have the same validity"
154+
);
155+
156+
Ok(ext)
157+
}
158+
159+
/// Returns a reference to the backing vector array that has been unit normalized.
///
/// This child may be nullable, and its validity determines the validity of the
/// `NormVectorArray`.
160+
pub fn vector_array(&self) -> &ArrayRef {
161+
&self.vector_array
162+
}
163+
164+
/// Returns a reference to the L2 norms of each vector.
///
/// The norms must have the same validity as the vector array, and the same dtype as the
/// elements of the vector array.
164+
pub fn norms(&self) -> &ArrayRef {
165+
&self.norms
166+
}
168+
112169
/// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and
113170
/// dividing each vector by its norm.
114171
///
@@ -118,9 +175,9 @@ impl NormVectorArray {
118175
///
119176
/// Note that compression is lossy per floating-point operations.
120177
pub fn compress(vector_array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Self> {
121-
let ext = Self::validate(&vector_array)?;
178+
let ext = Self::validate_vector_array(&vector_array)?;
122179

123-
let list_size = extension_list_size(&ext)?;
180+
let list_size = extension_list_size(&ext)? as usize;
124181
let row_count = vector_array.len();
125182
let nullability = Nullability::from(vector_array.dtype().is_nullable());
126183
let validity = vector_array.validity()?;
@@ -170,57 +227,62 @@ impl NormVectorArray {
170227
})
171228
}
172229

173-
/// Returns a reference to the backing vector array that has been unit normalized.
174-
pub fn vector_array(&self) -> &ArrayRef {
175-
&self.vector_array
176-
}
177-
178-
/// Returns a reference to the L2 norms of each vector.
179-
pub fn norms(&self) -> &ArrayRef {
180-
&self.norms
181-
}
182-
183230
/// Reconstructs the original vectors by multiplying each unit-normalized vector by its L2 norm.
184231
///
185232
/// The returned array has the same dtype (including nullability) as the original
186233
/// `vector_array` child.
187234
pub fn decompress(&self, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
188-
let ext = Self::validate(&self.vector_array)?;
189-
let nullability = Nullability::from(self.vector_array.dtype().is_nullable());
190-
191-
let list_size = extension_list_size(&ext)?;
192-
let row_count = self.vector_array.len();
235+
let ext = self
236+
.dtype()
237+
.as_extension_opt()
238+
.vortex_expect("somehow had a non-extension dtype");
193239

194240
let storage = extension_storage(&self.vector_array)?;
195-
let flat = extract_flat_elements(&storage, list_size, ctx)?;
241+
let fsl: FixedSizeListArray = storage.execute(ctx)?;
196242

197-
let norms_prim: PrimitiveArray = self.norms.clone().execute(ctx)?;
243+
let denormalized_fsl =
244+
broadcast_binary_to_elements(fsl, self.norms.clone(), Operator::Mul, ctx)?;
198245

199-
match_each_float_ptype!(flat.ptype(), |T| {
200-
let norms_slice = norms_prim.as_slice::<T>();
201-
202-
let result_elems: PrimitiveArray = (0..row_count)
203-
.flat_map(|i| {
204-
let norm = norms_slice[i];
205-
flat.row::<T>(i).iter().map(move |&v| v * norm)
206-
})
207-
.collect();
208-
209-
let validity = Validity::from(nullability);
210-
let fsl = FixedSizeListArray::new(
211-
result_elems.into_array(),
212-
u32::try_from(list_size)?,
213-
validity,
214-
row_count,
215-
);
216-
217-
let ext_dtype =
218-
ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
219-
Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
220-
})
246+
Ok(ExtensionArray::new(ext.clone(), denormalized_fsl.into_array()).into_array())
221247
}
222248
}
223249

250+
/// We do not have any kind of "broadcast" expression where we evaluate a binary expression between
251+
/// every `FixedSizeList` element and another value. We can mimic this by creating a
252+
/// `RunEnd(Sequence)` array that we evaluate with the elements of the [`FixedSizeListArray`].
253+
fn broadcast_binary_to_elements(
254+
fsl: FixedSizeListArray,
255+
values: ArrayRef,
256+
op: Operator,
257+
ctx: &mut ExecutionCtx,
258+
) -> VortexResult<FixedSizeListArray> {
259+
let num_lists = fsl.len();
260+
let list_size = fsl.list_size();
261+
let validity = fsl.validity()?;
262+
let elements = fsl.elements();
263+
debug_assert!(elements.dtype().is_primitive());
264+
265+
// Create the broadcasting array via a runend array with a sequence of ends.
266+
let base: PValue = list_size.into();
267+
let multiplier: PValue = base;
268+
let ends_ptype = base.ptype();
269+
let ends_nullability = Nullability::NonNullable;
270+
271+
let ends = SequenceArray::try_new(base, multiplier, ends_ptype, ends_nullability, num_lists)?;
272+
let runend = RunEndArray::try_new(ends.into_array(), values)?;
273+
274+
let binary_eval = elements.binary(runend.into_array(), op)?;
275+
let executed: PrimitiveArray = binary_eval.execute(ctx)?;
276+
277+
// SAFETY: We simply evaluated a scalar function on all of the elements, so none of the length
278+
// properties have changed.
279+
let fsl = unsafe {
280+
FixedSizeListArray::new_unchecked(executed.into_array(), list_size, validity, num_lists)
281+
};
282+
283+
Ok(fsl)
284+
}
285+
224286
/// Returns `1 / norm` if the norm is non-zero, or zero otherwise.
225287
///
226288
/// This avoids division by zero for zero-length or all-zero vectors.

vortex-tensor/src/encodings/norm/vtable/operations.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ impl OperationsVTable<NormVector> for NormVector {
2929
array.vector_array().dtype()
3030
)
3131
})?;
32-
let list_size = extension_list_size(ext)?;
32+
let list_size = extension_list_size(ext)? as usize;
3333

3434
// Get the storage (FixedSizeList) and slice out the elements for this row.
3535
let storage = extension_storage(array.vector_array())?;

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ impl ScalarFnVTable for CosineSimilarity {
128128
lhs.dtype()
129129
)
130130
})?;
131-
let list_size = extension_list_size(ext)?;
131+
let list_size = extension_list_size(ext)? as usize;
132132

133133
// Extract the storage array from each extension input. We pass the storage (FSL) rather
134134
// than the extension array to avoid canonicalizing the extension wrapper.

vortex-tensor/src/scalar_fns/l2_norm.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ impl ScalarFnVTable for L2Norm {
110110
input.dtype()
111111
)
112112
})?;
113-
let list_size = extension_list_size(ext)?;
113+
let list_size = extension_list_size(ext)? as usize;
114114

115115
let storage = extension_storage(&input)?;
116116
let flat = extract_flat_elements(&storage, list_size, ctx)?;

vortex-tensor/src/utils.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ use vortex::error::vortex_err;
2121
/// Extracts the list size from a tensor-like extension dtype.
2222
///
2323
/// The storage dtype must be a `FixedSizeList`.
24-
pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult<usize> {
24+
pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult<u32> {
2525
let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else {
2626
vortex_bail!(
2727
"expected FixedSizeList storage dtype, got {}",
2828
ext.storage_dtype()
2929
);
3030
};
3131

32-
Ok(*list_size as usize)
32+
Ok(*list_size)
3333
}
3434

3535
/// Extracts the float element [`PType`] from a tensor-like extension dtype.
@@ -232,7 +232,7 @@ pub mod test_helpers {
232232
.dtype()
233233
.as_extension_opt()
234234
.ok_or_else(|| vortex_err!("expected Vector extension dtype, got {}", array.dtype()))?;
235-
let list_size = extension_list_size(ext)?;
235+
let list_size = extension_list_size(ext)? as usize;
236236
let storage = extension_storage(array)?;
237237
let flat = extract_flat_elements(&storage, list_size, ctx)?;
238238
Ok((0..array.len())

0 commit comments

Comments
 (0)