@@ -215,11 +215,11 @@ def test_batch_meta_chunk_by_partition(self):
215215 name = "test_field" , dtype = torch .float32 , shape = (2 ,), production_status = ProductionStatus .READY_FOR_CONSUME
216216 )
217217 }
218- samples = [SampleMeta (partition_id = f"partition_{ i % 4 } " , global_index = i , fields = fields ) for i in range (10 )]
218+ samples = [SampleMeta (partition_id = f"partition_{ i % 4 } " , global_index = i + 10 , fields = fields ) for i in range (10 )]
219219 batch = BatchMeta (
220220 samples = samples ,
221- custom_meta = {i : {"uid" : i } for i in range (10 )},
222- _custom_backend_meta = {i : {"test_field" : {"dtype" : torch .float32 }} for i in range (10 )},
221+ custom_meta = {i + 10 : {"uid" : i + 10 } for i in range (10 )},
222+ _custom_backend_meta = {i + 10 : {"test_field" : {"dtype" : torch .float32 }} for i in range (10 )},
223223 )
224224
225225 # Chunk according to partition_id
@@ -228,30 +228,30 @@ def test_batch_meta_chunk_by_partition(self):
228228 assert len (chunks ) == 4
229229 assert len (chunks [0 ]) == 3
230230 assert chunks [0 ].partition_ids == ["partition_0" , "partition_0" , "partition_0" ]
231- assert chunks [0 ].global_indexes == [0 , 4 , 8 ]
231+ assert chunks [0 ].global_indexes == [10 , 14 , 18 ]
232232 assert len (chunks [1 ]) == 3
233233 assert chunks [1 ].partition_ids == ["partition_1" , "partition_1" , "partition_1" ]
234- assert chunks [1 ].global_indexes == [1 , 5 , 9 ]
234+ assert chunks [1 ].global_indexes == [11 , 15 , 19 ]
235235 assert len (chunks [2 ]) == 2
236236 assert chunks [2 ].partition_ids == ["partition_2" , "partition_2" ]
237- assert chunks [2 ].global_indexes == [2 , 6 ]
237+ assert chunks [2 ].global_indexes == [12 , 16 ]
238238 assert len (chunks [3 ]) == 2
239239 assert chunks [3 ].partition_ids == ["partition_3" , "partition_3" ]
240- assert chunks [3 ].global_indexes == [3 , 7 ]
240+ assert chunks [3 ].global_indexes == [13 , 17 ]
241241
242242 # validate _custom_backend_meta is chunked
243- assert 0 in chunks [0 ].custom_meta
244- assert 4 in chunks [0 ].custom_meta
245- assert 8 in chunks [0 ].custom_meta
246- assert 1 not in chunks [0 ].custom_meta
247- assert 1 in chunks [1 ].custom_meta
243+ assert 10 in chunks [0 ].custom_meta
244+ assert 14 in chunks [0 ].custom_meta
245+ assert 18 in chunks [0 ].custom_meta
246+ assert 11 not in chunks [0 ].custom_meta
247+ assert 11 in chunks [1 ].custom_meta
248248
249249 # validate _custom_backend_meta is chunked
250- assert 0 in chunks [0 ]._custom_backend_meta
251- assert 4 in chunks [0 ]._custom_backend_meta
252- assert 8 in chunks [0 ]._custom_backend_meta
253- assert 1 not in chunks [0 ]._custom_backend_meta
254- assert 1 in chunks [1 ]._custom_backend_meta
250+ assert 10 in chunks [0 ]._custom_backend_meta
251+ assert 14 in chunks [0 ]._custom_backend_meta
252+ assert 18 in chunks [0 ]._custom_backend_meta
253+ assert 11 not in chunks [0 ]._custom_backend_meta
254+ assert 11 in chunks [1 ]._custom_backend_meta
255255
256256 def test_batch_meta_init_validation_error_different_field_names (self ):
257257 """Example: Init validation catches samples with different field names."""
0 commit comments