File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -516,7 +516,7 @@ def _partition_assignment() -> Iterable[pa.RecordBatch]:
516516
517517 split_columns = []
518518 if num_sub_vectors is not None :
519- residual_vecs = vecs - kmeans .centroids [partitions ]
519+ residual_vecs = vecs [ mask_gpu ] - kmeans .centroids [partitions ]
520520 for i in range (num_sub_vectors ):
521521 subvector_tensor = residual_vecs [
522522 :, i * subvector_size : (i + 1 ) * subvector_size
@@ -685,7 +685,7 @@ def _partition_and_pq_codes_assignment() -> Iterable[pa.RecordBatch]:
685685 assert vecs .shape [0 ] == ids .shape [0 ]
686686
687687 # Ignore any invalid vectors.
688- mask_gpu = partitions .isfinite ()
688+ mask_gpu = partitions .isfinite () & ( partitions >= 0 )
689689 ids = ids .to (ivf_kmeans .device )[mask_gpu ].cpu ().reshape (- 1 )
690690 partitions = partitions [mask_gpu ].cpu ()
691691 vecs = vecs [mask_gpu ]
@@ -746,8 +746,6 @@ def _partition_and_pq_codes_assignment() -> Iterable[pa.RecordBatch]:
746746 LOGGER .info ("Saved precomputed pq_codes to %s" , dst_dataset_uri )
747747
748748 shuffle_buffers = [
749- data_file .path ()
750- for frag in ds .get_fragments ()
751- for data_file in frag .data_files ()
749+ data_file .path for frag in ds .get_fragments () for data_file in frag .data_files ()
752750 ]
753751 return dst_dataset_uri , shuffle_buffers
Original file line number Diff line number Diff line change @@ -254,6 +254,22 @@ def test_index_with_nans(tmp_path):
254254 validate_vector_index (dataset , "vector" )
255255
256256
257+ def test_torch_index_with_nans (tmp_path ):
258+ # 1000 vectors plus 24 NaN vectors = 1024 rows; the entire table should be sampled
259+ tbl = create_table (nvec = 1000 , nans = 24 )
260+
261+ dataset = lance .write_dataset (tbl , tmp_path )
262+ dataset = dataset .create_index (
263+ "vector" ,
264+ index_type = "IVF_PQ" ,
265+ num_partitions = 4 ,
266+ num_sub_vectors = 16 ,
267+ accelerator = torch .device ("cpu" ),
268+ one_pass_ivfpq = True ,
269+ )
270+ validate_vector_index (dataset , "vector" )
271+
272+
257273def test_index_with_no_centroid_movement (tmp_path ):
258274 # this test makes the centroids essentially [1..]
259275 # this makes sure the early stop condition in the index building code
You can’t perform that action at this time.
0 commit comments