Skip to content

Commit 79cb061

Browse files
committed
fix masking
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 62273ba commit 79cb061

1 file changed

Lines changed: 10 additions & 3 deletions

File tree

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use vortex_array::scalar_fn::ExecutionArgs;
2424
use vortex_array::scalar_fn::ScalarFn;
2525
use vortex_array::scalar_fn::ScalarFnId;
2626
use vortex_array::scalar_fn::ScalarFnVTable;
27+
use vortex_array::validity::Validity;
2728
use vortex_buffer::Buffer;
2829
use vortex_error::VortexResult;
2930
use vortex_error::vortex_ensure;
@@ -222,9 +223,15 @@ impl CosineSimilarity {
222223
let (normalized_r, _) = extract_l2_denorm_children(rhs_ref);
223224

224225
// Dot product of already-normalized children IS the cosine similarity.
225-
InnerProduct::try_new_array(options, normalized_l, normalized_r, len)?
226-
.into_array()
227-
.mask(validity.to_array(len))
226+
let dot =
227+
InnerProduct::try_new_array(options, normalized_l, normalized_r, len)?.into_array();
228+
229+
if !matches!(validity, Validity::NonNullable) {
230+
// Masking always changes the nullability to nullable.
231+
dot.mask(validity.to_array(len))
232+
} else {
233+
Ok(dot)
234+
}
228235
}
229236

230237
/// One side is `L2Denorm`: `cosine_similarity = dot(n, b) / ||b||`.

0 commit comments

Comments
 (0)