Skip to content

Commit 5678362

Browse files
committed
potential fix
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent b9fee0b commit 5678362

1 file changed

Lines changed: 15 additions & 3 deletions

File tree

  • vortex-tensor/src/encodings/norm

vortex-tensor/src/encodings/norm/array.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use num_traits::Float;
5+
use num_traits::Zero;
56
use vortex::array::ArrayRef;
67
use vortex::array::ExecutionCtx;
78
use vortex::array::IntoArray;
@@ -57,7 +58,8 @@ pub struct NormVectorArray {
5758
}
5859

5960
impl NormVectorArray {
60-
/// Creates a new [`NormVectorArray`] from a unit-normalized vector array and its L2 norms.
61+
/// Creates a new [`NormVectorArray`] from a unit-normalized vector array and associated L2
62+
/// norms for each vector.
6163
///
6264
/// The `vector_array` must be a [`Vector`] extension array with floating-point elements, and
6365
/// `norms` must be a primitive array of the same float type with the same length. The
@@ -113,12 +115,15 @@ impl NormVectorArray {
113115
/// The input must be a [`Vector`] extension array with floating-point elements. Nullable inputs
114116
/// are supported; the validity mask is preserved and the normalized data for null rows is
115117
/// unspecified.
118+
///
119+
/// Note that compression is lossy: scaling each element by the inverse norm is subject to
/// floating-point rounding.
116120
pub fn compress(vector_array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Self> {
117121
let ext = Self::validate(&vector_array)?;
118122

119123
let list_size = extension_list_size(&ext)?;
120124
let row_count = vector_array.len();
121125
let nullability = Nullability::from(vector_array.dtype().is_nullable());
126+
let validity = vector_array.validity()?;
122127

123128
// Compute L2 norms using the scalar function. If the input is nullable, the norms will
124129
// also be nullable (null vectors produce null norms).
@@ -135,10 +140,17 @@ impl NormVectorArray {
135140
let norms_slice = norms_prim.as_slice::<T>();
136141

137142
let normalized_elems: PrimitiveArray = (0..row_count)
138-
.flat_map(|i| {
143+
.map(|i| -> VortexResult<Vec<T>> {
144+
if !validity.is_valid(i)? {
145+
return Ok(vec![T::zero(); list_size]);
146+
}
147+
139148
let inv_norm = safe_inv_norm(norms_slice[i]);
140-
flat.row::<T>(i).iter().map(move |&v| v * inv_norm)
149+
Ok(flat.row::<T>(i).iter().map(|&v| v * inv_norm).collect())
141150
})
151+
.collect::<VortexResult<Vec<Vec<T>>>>()?
152+
.into_iter()
153+
.flatten()
142154
.collect();
143155

144156
// Reconstruct the vector array with the same nullability as the input.

0 commit comments

Comments
 (0)