Skip to content

Commit 6028e79

Browse files
authored
fix(udf): validate query vector dimensions before computing distances (#15)
Distance UDFs (l2_distance, cosine_distance, negative_dot_product) used zip() to iterate over the column and query vectors, which silently truncates to the shorter length on a dimension mismatch. As a result, a mismatch returned wrong results instead of an error — especially dangerous on the brute-force fallback path (SELECT *), where USearch dimension validation is bypassed. This commit validates dimensions before computing distances and returns an execution error on mismatch. For FixedSizeList, the dimension is known from the type, so a single upfront check suffices (zero per-row cost). For List/LargeList, rows are variable-length, so the check runs per row (one branch, negligible compared to the kernel).
1 parent b0fdce3 commit 6028e79

2 files changed

Lines changed: 361 additions & 3 deletions

File tree

src/udf.rs

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,23 @@ fn inner_to_f32(inner: &dyn Array, udf_name: &str) -> Result<Vec<f32>> {
139139
///
140140
/// Supports all outer array types (FixedSizeList, List, LargeList) and
141141
/// inner element types (Float32, Float64 — Float64 is cast to f32 for the kernel).
142+
///
143+
/// Returns an error if vector dimensions do not match the query length.
142144
fn compute_distances(
143145
vec_col: &dyn Array,
144146
query_vec: &[f32],
145147
kernel: Kernel,
146148
udf_name: &str,
147149
) -> Result<Vec<Option<f32>>> {
148-
// FixedSizeListArray — typical for DuckDB FLOAT[N] or pre-cast columns
150+
// FixedSizeListArray — dimension known from type, validate once upfront.
149151
if let Some(fsl) = vec_col.as_any().downcast_ref::<FixedSizeListArray>() {
152+
let col_dim = fsl.value_length() as usize;
153+
if col_dim != query_vec.len() {
154+
return Err(DataFusionError::Execution(format!(
155+
"{udf_name}: query vector length ({}) must match column dimensionality ({col_dim})",
156+
query_vec.len(),
157+
)));
158+
}
150159
let mut out = Vec::with_capacity(fsl.len());
151160
for i in 0..fsl.len() {
152161
if fsl.is_null(i) {
@@ -159,7 +168,7 @@ fn compute_distances(
159168
return Ok(out);
160169
}
161170

162-
// ListArray — variable-length, e.g. PostgreSQL real[] / float8[]
171+
// ListArray — variable-length, validate per row.
163172
if let Some(lst) = vec_col.as_any().downcast_ref::<ListArray>() {
164173
let mut out = Vec::with_capacity(lst.len());
165174
for i in 0..lst.len() {
@@ -168,12 +177,19 @@ fn compute_distances(
168177
continue;
169178
}
170179
let f32s = inner_to_f32(&*lst.value(i), udf_name)?;
180+
if f32s.len() != query_vec.len() {
181+
return Err(DataFusionError::Execution(format!(
182+
"{udf_name}: query vector length ({}) must match row {i} dimensionality ({})",
183+
query_vec.len(),
184+
f32s.len(),
185+
)));
186+
}
171187
out.push(Some(kernel(&f32s, query_vec)));
172188
}
173189
return Ok(out);
174190
}
175191

176-
// LargeListArray — large-offset variant, e.g. some Postgres/Parquet encodings
192+
// LargeListArray — large-offset variant, validate per row.
177193
if let Some(lst) = vec_col.as_any().downcast_ref::<LargeListArray>() {
178194
let mut out = Vec::with_capacity(lst.len());
179195
for i in 0..lst.len() {
@@ -182,6 +198,13 @@ fn compute_distances(
182198
continue;
183199
}
184200
let f32s = inner_to_f32(&*lst.value(i), udf_name)?;
201+
if f32s.len() != query_vec.len() {
202+
return Err(DataFusionError::Execution(format!(
203+
"{udf_name}: query vector length ({}) must match row {i} dimensionality ({})",
204+
query_vec.len(),
205+
f32s.len(),
206+
)));
207+
}
185208
out.push(Some(kernel(&f32s, query_vec)));
186209
}
187210
return Ok(out);

0 commit comments

Comments
 (0)