Skip to content

Commit f422aeb

Browse files
committed
fix
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent bc34a6a commit f422aeb

2 files changed

Lines changed: 37 additions & 29 deletions

File tree

tests/test_metadata.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,11 @@ def test_batch_meta_chunk_by_partition(self):
215215
name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME
216216
)
217217
}
218-
samples = [SampleMeta(partition_id=f"partition_{i % 4}", global_index=i, fields=fields) for i in range(10)]
218+
samples = [SampleMeta(partition_id=f"partition_{i % 4}", global_index=i + 10, fields=fields) for i in range(10)]
219219
batch = BatchMeta(
220220
samples=samples,
221-
custom_meta={i: {"uid": i} for i in range(10)},
222-
_custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in range(10)},
221+
custom_meta={i + 10: {"uid": i + 10} for i in range(10)},
222+
_custom_backend_meta={i + 10: {"test_field": {"dtype": torch.float32}} for i in range(10)},
223223
)
224224

225225
# Chunk according to partition_id
@@ -228,30 +228,30 @@ def test_batch_meta_chunk_by_partition(self):
228228
assert len(chunks) == 4
229229
assert len(chunks[0]) == 3
230230
assert chunks[0].partition_ids == ["partition_0", "partition_0", "partition_0"]
231-
assert chunks[0].global_indexes == [0, 4, 8]
231+
assert chunks[0].global_indexes == [10, 14, 18]
232232
assert len(chunks[1]) == 3
233233
assert chunks[1].partition_ids == ["partition_1", "partition_1", "partition_1"]
234-
assert chunks[1].global_indexes == [1, 5, 9]
234+
assert chunks[1].global_indexes == [11, 15, 19]
235235
assert len(chunks[2]) == 2
236236
assert chunks[2].partition_ids == ["partition_2", "partition_2"]
237-
assert chunks[2].global_indexes == [2, 6]
237+
assert chunks[2].global_indexes == [12, 16]
238238
assert len(chunks[3]) == 2
239239
assert chunks[3].partition_ids == ["partition_3", "partition_3"]
240-
assert chunks[3].global_indexes == [3, 7]
240+
assert chunks[3].global_indexes == [13, 17]
241241

242242
# validate _custom_backend_meta is chunked
243-
assert 0 in chunks[0].custom_meta
244-
assert 4 in chunks[0].custom_meta
245-
assert 8 in chunks[0].custom_meta
246-
assert 1 not in chunks[0].custom_meta
247-
assert 1 in chunks[1].custom_meta
243+
assert 10 in chunks[0].custom_meta
244+
assert 14 in chunks[0].custom_meta
245+
assert 18 in chunks[0].custom_meta
246+
assert 11 not in chunks[0].custom_meta
247+
assert 11 in chunks[1].custom_meta
248248

249249
# validate _custom_backend_meta is chunked
250-
assert 0 in chunks[0]._custom_backend_meta
251-
assert 4 in chunks[0]._custom_backend_meta
252-
assert 8 in chunks[0]._custom_backend_meta
253-
assert 1 not in chunks[0]._custom_backend_meta
254-
assert 1 in chunks[1]._custom_backend_meta
250+
assert 10 in chunks[0]._custom_backend_meta
251+
assert 14 in chunks[0]._custom_backend_meta
252+
assert 18 in chunks[0]._custom_backend_meta
253+
assert 11 not in chunks[0]._custom_backend_meta
254+
assert 11 in chunks[1]._custom_backend_meta
255255

256256
def test_batch_meta_init_validation_error_different_field_names(self):
257257
"""Example: Init validation catches samples with different field names."""

transfer_queue/metadata.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -409,13 +409,10 @@ def select_samples(self, indexes: list[int]) -> "BatchMeta":
409409

410410
selected_samples = [self.samples[i] for i in indexes]
411411

412-
selected_custom_meta = {
413-
i: self.custom_meta[self.global_indexes[i]] for i in indexes if self.global_indexes[i] in self.custom_meta
414-
}
412+
global_indexes = [self.global_indexes[i] for i in indexes]
413+
selected_custom_meta = {i: self.custom_meta[i] for i in global_indexes if i in self.custom_meta}
415414
selected_custom_backend_meta = {
416-
i: self._custom_backend_meta[self.global_indexes[i]]
417-
for i in indexes
418-
if self.global_indexes[i] in self._custom_backend_meta
415+
i: self._custom_backend_meta[i] for i in global_indexes if i in self._custom_backend_meta
419416
}
420417

421418
# construct new BatchMeta instance
@@ -470,11 +467,22 @@ def __getitem__(self, item):
470467
if isinstance(item, int | np.integer):
471468
sample_meta = self.samples[item] if self.samples else []
472469
global_idx = self.global_indexes[item]
470+
471+
if global_idx in self.custom_meta:
472+
custom_meta = {global_idx: self.custom_meta[global_idx]}
473+
else:
474+
custom_meta = {}
475+
476+
if global_idx in self._custom_backend_meta:
477+
custom_backend_meta = {global_idx: self._custom_backend_meta[global_idx]}
478+
else:
479+
custom_backend_meta = {}
480+
473481
return BatchMeta(
474482
samples=[sample_meta],
475483
extra_info=self.extra_info,
476-
custom_meta={global_idx: self.custom_meta[global_idx]},
477-
_custom_backend_meta={global_idx: self._custom_backend_meta[global_idx]},
484+
custom_meta=custom_meta,
485+
_custom_backend_meta=custom_backend_meta,
478486
)
479487
else:
480488
raise TypeError(f"Indexing with {type(item)} is not supported now!")
@@ -533,11 +541,11 @@ def chunk_by_partition(
533541
List of smaller BatchMeta chunks, each chunk has samples with identical partition_id
534542
"""
535543

536-
grouped_global_indexes = defaultdict(list)
537-
for partition_id, global_index in zip(self.partition_ids, self.global_indexes, strict=False):
538-
grouped_global_indexes[partition_id].append(global_index)
544+
grouped_indexes = defaultdict(list)
545+
for partition_id, indexes in zip(self.partition_ids, range(self.size), strict=False):
546+
grouped_indexes[partition_id].append(indexes)
539547

540-
chunk_list = [self.select_samples(global_indices) for global_indices in grouped_global_indexes.values()]
548+
chunk_list = [self.select_samples(idx) for idx in grouped_indexes.values()]
541549

542550
return chunk_list
543551

0 commit comments

Comments
 (0)