Skip to content

Commit e519d2d

Browse files
refactor: refine metadata comments and improve field extraction logic (#63) (#64)
Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
Co-authored-by: Jianjun Zhong <87791082+jianjunzhong@users.noreply.github.com>
1 parent d32cc08 commit e519d2d

1 file changed

Lines changed: 13 additions & 11 deletions

File tree

transfer_queue/metadata.py

Lines changed: 13 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -32,8 +32,8 @@ class FieldMeta:
3232
name: str
3333

3434
# data schema info
35-
dtype: Optional[Any] # e.g., torch.float32, np, etc.
36-
shape: Optional[Any] # e.g., torch.Size([seq_len]), torch.Size([seq_len, feature_dim]), etc.
35+
dtype: Optional[Any] # if data has dtype attribute, e.g., torch.float32, numpy.float32, etc.
36+
shape: Optional[Any] # if data has shape attribute, e.g., torch.Size([3, 224, 224]), (3, 224, 224), etc.
3737

3838
# data status info
3939
production_status: ProductionStatus = ProductionStatus.NOT_PRODUCED # production status for this field
@@ -354,12 +354,13 @@ def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "Ba
354354
set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. Default is True.
355355
"""
356356
fields = _extract_field_metas(tensor_dict, set_all_ready)
357-
for idx, sample in enumerate(self.samples):
358-
sample.add_fields(fields=fields[idx])
357+
if len(fields) > 0:
358+
for idx, sample in enumerate(self.samples):
359+
sample.add_fields(fields=fields[idx])
359360

360-
# Update batch-level fields cache
361-
object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names))
362-
object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))
361+
# Update batch-level fields cache
362+
object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names))
363+
object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))
363364
return self
364365

365366
def __len__(self) -> int:
@@ -508,9 +509,10 @@ def _update_after_reorder(self) -> None:
508509
object.__setattr__(self, "_local_indexes", [sample.local_index for sample in self.samples])
509510
object.__setattr__(self, "_storage_ids", [sample.storage_id for sample in self.samples])
510511

511-
# Rebuild storage groups
512-
storage_meta_groups = self._build_storage_meta_groups()
513-
object.__setattr__(self, "_storage_meta_groups", storage_meta_groups)
512+
# Note: No need to rebuild storage_meta_groups as samples' storage_id remain unchanged
513+
# and their order does not affect the grouping
514+
# storage_meta_groups = self._build_storage_meta_groups()
515+
# object.__setattr__(self, "_storage_meta_groups", storage_meta_groups)
514516

515517
# Note: No need to update _size, _field_names, _is_ready, etc., as these remain unchanged after reorder
516518

@@ -580,7 +582,7 @@ def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) ->
580582
Otherwise, set to NOT_PRODUCED. Default is True.
581583
582584
Returns:
583-
all_fields (list[dict[FieldMeta]]): A list of dictionaries containing field metadata.
585+
all_fields (list[dict[str, FieldMeta]]): A list of dictionaries containing field metadata.
584586
"""
585587
all_fields = []
586588
batch_size = tensor_dict.batch_size[0]

0 commit comments

Comments
 (0)