Skip to content

Commit 039c6c6

Browse files
fix: one more finite vector fix (#3673)
one more place that needed valid part id filter
1 parent 6d186d3 commit 039c6c6

2 files changed

Lines changed: 19 additions & 5 deletions

File tree

python/python/lance/vector.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ def _partition_assignment() -> Iterable[pa.RecordBatch]:
516516

517517
split_columns = []
518518
if num_sub_vectors is not None:
519-
residual_vecs = vecs - kmeans.centroids[partitions]
519+
residual_vecs = vecs[mask_gpu] - kmeans.centroids[partitions]
520520
for i in range(num_sub_vectors):
521521
subvector_tensor = residual_vecs[
522522
:, i * subvector_size : (i + 1) * subvector_size
@@ -685,7 +685,7 @@ def _partition_and_pq_codes_assignment() -> Iterable[pa.RecordBatch]:
685685
assert vecs.shape[0] == ids.shape[0]
686686

687687
# Ignore any invalid vectors.
688-
mask_gpu = partitions.isfinite()
688+
mask_gpu = partitions.isfinite() & (partitions >= 0)
689689
ids = ids.to(ivf_kmeans.device)[mask_gpu].cpu().reshape(-1)
690690
partitions = partitions[mask_gpu].cpu()
691691
vecs = vecs[mask_gpu]
@@ -746,8 +746,6 @@ def _partition_and_pq_codes_assignment() -> Iterable[pa.RecordBatch]:
746746
LOGGER.info("Saved precomputed pq_codes to %s", dst_dataset_uri)
747747

748748
shuffle_buffers = [
749-
data_file.path()
750-
for frag in ds.get_fragments()
751-
for data_file in frag.data_files()
749+
data_file.path for frag in ds.get_fragments() for data_file in frag.data_files()
752750
]
753751
return dst_dataset_uri, shuffle_buffers

python/python/tests/test_vector_index.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,22 @@ def test_index_with_nans(tmp_path):
254254
validate_vector_index(dataset, "vector")
255255

256256

257+
def test_torch_index_with_nans(tmp_path):
258+
# 1024 rows, the entire table should be sampled
259+
tbl = create_table(nvec=1000, nans=24)
260+
261+
dataset = lance.write_dataset(tbl, tmp_path)
262+
dataset = dataset.create_index(
263+
"vector",
264+
index_type="IVF_PQ",
265+
num_partitions=4,
266+
num_sub_vectors=16,
267+
accelerator=torch.device("cpu"),
268+
one_pass_ivfpq=True,
269+
)
270+
validate_vector_index(dataset, "vector")
271+
272+
257273
def test_index_with_no_centroid_movement(tmp_path):
258274
# this test makes the centroids essentially [1..]
259275
# this makes sure the early stop condition in the index building code

0 commit comments

Comments
 (0)