Skip to content

Commit 236c8f9

Browse files
authored
fix: filter out infinite vectors (#3648)
When training IVF/PQ, we sample the dataset and filter out nulls, but we still need to filter out vectors that contain NaN or infinite values. --------- Signed-off-by: BubbleCal <bubble-cal@outlook.com>
1 parent 9b0c274 commit 236c8f9

5 files changed

Lines changed: 95 additions & 18 deletions

File tree

rust/lance-index/src/vector/ivf.rs

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,8 @@ impl IvfTransformer {
118118
vector_column: &str,
119119
range: Option<Range<u32>>,
120120
) -> Self {
121-
let mut transforms: Vec<Arc<dyn Transformer>> = vec![
122-
Arc::new(KeepFiniteVectors::new(vector_column)),
123-
Arc::new(super::transform::Flatten::new(vector_column)),
124-
];
121+
let mut transforms: Vec<Arc<dyn Transformer>> =
122+
vec![Arc::new(super::transform::Flatten::new(vector_column))];
125123

126124
let dt = if distance_type == DistanceType::Cosine {
127125
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
@@ -131,6 +129,7 @@ impl IvfTransformer {
131129
} else {
132130
distance_type
133131
};
132+
transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
134133

135134
let ivf_transform = Arc::new(PartitionTransformer::new(
136135
centroids.clone(),
@@ -159,10 +158,8 @@ impl IvfTransformer {
159158
pq: ProductQuantizer,
160159
range: Option<Range<u32>>,
161160
) -> Self {
162-
let mut transforms: Vec<Arc<dyn Transformer>> = vec![
163-
Arc::new(KeepFiniteVectors::new(vector_column)),
164-
Arc::new(super::transform::Flatten::new(vector_column)),
165-
];
161+
let mut transforms: Vec<Arc<dyn Transformer>> =
162+
vec![Arc::new(super::transform::Flatten::new(vector_column))];
166163

167164
let distance_type = if distance_type == MetricType::Cosine {
168165
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
@@ -172,6 +169,7 @@ impl IvfTransformer {
172169
} else {
173170
distance_type
174171
};
172+
transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
175173

176174
let partition_transform = Arc::new(PartitionTransformer::new(
177175
centroids.clone(),
@@ -210,10 +208,8 @@ impl IvfTransformer {
210208
sq: ScalarQuantizer,
211209
range: Option<Range<u32>>,
212210
) -> Self {
213-
let mut transforms: Vec<Arc<dyn Transformer>> = vec![
214-
Arc::new(KeepFiniteVectors::new(vector_column)),
215-
Arc::new(super::transform::Flatten::new(vector_column)),
216-
];
211+
let mut transforms: Vec<Arc<dyn Transformer>> =
212+
vec![Arc::new(super::transform::Flatten::new(vector_column))];
217213

218214
let distance_type = if metric_type == MetricType::Cosine {
219215
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
@@ -223,6 +219,7 @@ impl IvfTransformer {
223219
} else {
224220
metric_type
225221
};
222+
transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
226223

227224
let partition_transformer = Arc::new(PartitionTransformer::new(
228225
centroids.clone(),

rust/lance-index/src/vector/utils.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use arrow::{
55
array::AsArray,
66
datatypes::{Float16Type, Float32Type, Float64Type},
77
};
8-
use arrow_array::{Array, FixedSizeListArray};
8+
use arrow_array::{Array, BooleanArray, FixedSizeListArray};
99
use arrow_schema::{DataType, Field};
1010
use lance_core::{Error, Result};
1111
use lance_io::encodings::plain::bytes_to_array;
@@ -164,6 +164,36 @@ impl TryFrom<&pb::Tensor> for FixedSizeListArray {
164164
}
165165
}
166166

167+
/// Compute a per-vector finiteness mask for the FixedSizeListArray.
168+
/// A null vector, or a vector containing any null element, is considered not finite.
169+
/// Returns a BooleanArray
170+
/// with the same length as the FixedSizeListArray,
171+
/// with true where every element of the vector is finite and false otherwise.
172+
pub fn is_finite(fsl: &FixedSizeListArray) -> BooleanArray {
173+
let is_finite = fsl
174+
.iter()
175+
.map(|v| match v {
176+
Some(v) => match v.data_type() {
177+
DataType::Float16 => {
178+
let v = v.as_primitive::<Float16Type>();
179+
v.null_count() == 0 && v.values().iter().all(|v| v.is_finite())
180+
}
181+
DataType::Float32 => {
182+
let v = v.as_primitive::<Float32Type>();
183+
v.null_count() == 0 && v.values().iter().all(|v| v.is_finite())
184+
}
185+
DataType::Float64 => {
186+
let v = v.as_primitive::<Float64Type>();
187+
v.null_count() == 0 && v.values().iter().all(|v| v.is_finite())
188+
}
189+
_ => v.null_count() == 0, // non-float element types: only the null check is applied
190+
},
191+
None => false, // a null vector is never reported as finite
192+
})
193+
.collect::<Vec<_>>();
194+
BooleanArray::from(is_finite)
195+
}
196+
167197
#[cfg(test)]
168198
mod tests {
169199
use super::*;

rust/lance/src/index/vector/builder.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use lance_index::vector::quantizer::{
2424
QuantizationMetadata, QuantizationType, QuantizerBuildParams,
2525
};
2626
use lance_index::vector::storage::STORAGE_METADATA_KEY;
27+
use lance_index::vector::utils::is_finite;
2728
use lance_index::vector::v3::shuffler::IvfShufflerReader;
2829
use lance_index::vector::v3::subindex::SubIndexType;
2930
use lance_index::vector::{VectorIndex, LOSS_METADATA_KEY, PART_ID_COLUMN, PQ_CODE_COLUMN};
@@ -370,6 +371,10 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
370371
training_data
371372
};
372373

374+
// Nulls were filtered out during sampling, but vectors containing NaN or Inf values must still be filtered out here.
375+
let training_data = arrow::compute::filter(&training_data, &is_finite(&training_data))?;
376+
let training_data = training_data.as_fixed_size_list();
377+
373378
let training_data = match (self.ivf.as_ref(), Q::use_residual(self.distance_type)) {
374379
(Some(ivf), true) => {
375380
let ivf_transformer = lance_index::vector::ivf::new_ivf_transformer(
@@ -378,9 +383,9 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
378383
vec![],
379384
);
380385
span!(Level::INFO, "compute residual for PQ training")
381-
.in_scope(|| ivf_transformer.compute_residual(&training_data))?
386+
.in_scope(|| ivf_transformer.compute_residual(training_data))?
382387
}
383-
_ => training_data,
388+
_ => training_data.clone(),
384389
};
385390

386391
info!("Start to train quantizer");

rust/lance/src/index/vector/ivf.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantize
4646
use lance_index::vector::ivf::storage::IvfModel;
4747
use lance_index::vector::pq::storage::transpose;
4848
use lance_index::vector::quantizer::QuantizationType;
49+
use lance_index::vector::utils::is_finite;
4950
use lance_index::vector::v3::shuffler::IvfShuffler;
5051
use lance_index::vector::v3::subindex::{IvfSubIndex, SubIndexType};
5152
use lance_index::{
@@ -1222,9 +1223,13 @@ pub async fn build_ivf_model(
12221223
(training_data, metric_type)
12231224
};
12241225

1226+
// Nulls were filtered out during sampling, but vectors containing NaN or Inf values must still be filtered out here.
1227+
let training_data = arrow::compute::filter(&training_data, &is_finite(&training_data))?;
1228+
let training_data = training_data.as_fixed_size_list();
1229+
12251230
info!("Start to train IVF model");
12261231
let start = std::time::Instant::now();
1227-
let ivf = train_ivf_model(centroids, &training_data, mt, params).await?;
1232+
let ivf = train_ivf_model(centroids, training_data, mt, params).await?;
12281233
info!(
12291234
"Trained IVF model in {:02} seconds",
12301235
start.elapsed().as_secs_f32()

rust/lance/src/index/vector/ivf/v2.rs

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -642,8 +642,8 @@ mod tests {
642642
use arrow::datatypes::{UInt64Type, UInt8Type};
643643
use arrow::{array::AsArray, datatypes::Float32Type};
644644
use arrow_array::{
645-
Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, FixedSizeListArray, ListArray,
646-
RecordBatch, RecordBatchIterator, UInt64Array,
645+
Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, FixedSizeListArray, Float32Array,
646+
ListArray, RecordBatch, RecordBatchIterator, UInt64Array,
647647
};
648648
use arrow_buffer::OffsetBuffer;
649649
use arrow_schema::{DataType, Field, Schema, SchemaRef};
@@ -1700,4 +1700,44 @@ mod tests {
17001700
assert_lt!(*d, dists[k - 1]);
17011701
});
17021702
}
1703+
1704+
#[tokio::test]
1705+
async fn test_index_with_zero_vectors() { // builds an IVF_PQ cosine index over data that includes an all-zero vector
1706+
let test_dir = tempdir().unwrap();
1707+
let test_uri = test_dir.path().to_str().unwrap();
1708+
let (batch, schema) = generate_batch::<Float32Type>(256, None, 0.0..1.0, false);
1709+
let vector_field = schema.field(1).clone();
1710+
let zero_batch = RecordBatch::try_new( // one extra row whose vector is all zeros
1711+
schema.clone(),
1712+
vec![
1713+
Arc::new(UInt64Array::from(vec![256])),
1714+
Arc::new(
1715+
FixedSizeListArray::try_new_from_values(
1716+
Float32Array::from(vec![0.0; DIM]), // DIM zeros, i.e. a single zero vector
1717+
DIM as i32,
1718+
)
1719+
.unwrap(),
1720+
),
1721+
],
1722+
)
1723+
.unwrap();
1724+
let batches = RecordBatchIterator::new(vec![batch, zero_batch].into_iter().map(Ok), schema);
1725+
let mut dataset = Dataset::write(
1726+
batches,
1727+
test_uri,
1728+
Some(WriteParams {
1729+
mode: crate::dataset::WriteMode::Overwrite,
1730+
..Default::default()
1731+
}),
1732+
)
1733+
.await
1734+
.unwrap();
1735+
1736+
let vector_column = vector_field.name();
1737+
let params = VectorIndexParams::ivf_pq(4, 8, DIM / 8, DistanceType::Cosine, 50); // NOTE(review): presumably cosine normalization of the zero vector yields non-finite values — confirm
1738+
dataset
1739+
.create_index(&[vector_column], IndexType::Vector, None, &params, true)
1740+
.await
1741+
.unwrap(); // no further assertions: the test passes if index creation returns Ok
1742+
}
17031743
}

0 commit comments

Comments
 (0)