@@ -32,8 +32,8 @@ class FieldMeta:
3232 name : str
3333
3434 # data schema info
35- dtype : Optional [Any ] # e.g., torch.float32, np , etc.
36- shape : Optional [Any ] # e.g., torch.Size([seq_len ]), torch.Size([seq_len, feature_dim] ), etc.
35+ dtype : Optional [Any ] # if data has dtype attribute, e.g., torch.float32, numpy.float32 , etc.
36+ shape : Optional [Any ] # if data has shape attribute, e.g., torch.Size([3, 224, 224 ]), (3, 224, 224 ), etc.
3737
3838 # data status info
3939 production_status : ProductionStatus = ProductionStatus .NOT_PRODUCED # production status for this field
@@ -354,12 +354,13 @@ def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "Ba
354354 set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. Default is True.
355355 """
356356 fields = _extract_field_metas (tensor_dict , set_all_ready )
357- for idx , sample in enumerate (self .samples ):
358- sample .add_fields (fields = fields [idx ])
357+ if len (fields ) > 0 :
358+ for idx , sample in enumerate (self .samples ):
359+ sample .add_fields (fields = fields [idx ])
359360
360- # Update batch-level fields cache
361- object .__setattr__ (self , "_field_names" , sorted (self .samples [0 ].field_names ))
362- object .__setattr__ (self , "_is_ready" , all (sample .is_ready for sample in self .samples ))
361+ # Update batch-level fields cache
362+ object .__setattr__ (self , "_field_names" , sorted (self .samples [0 ].field_names ))
363+ object .__setattr__ (self , "_is_ready" , all (sample .is_ready for sample in self .samples ))
363364 return self
364365
365366 def __len__ (self ) -> int :
@@ -508,9 +509,10 @@ def _update_after_reorder(self) -> None:
508509 object .__setattr__ (self , "_local_indexes" , [sample .local_index for sample in self .samples ])
509510 object .__setattr__ (self , "_storage_ids" , [sample .storage_id for sample in self .samples ])
510511
511- # Rebuild storage groups
512- storage_meta_groups = self ._build_storage_meta_groups ()
513- object .__setattr__ (self , "_storage_meta_groups" , storage_meta_groups )
512+ # Note: No need to rebuild storage_meta_groups here, because each sample's storage_id remains unchanged
513+ # and the order of the samples does not affect the grouping.
514+ # storage_meta_groups = self._build_storage_meta_groups()
515+ # object.__setattr__(self, "_storage_meta_groups", storage_meta_groups)
514516
515517 # Note: No need to update _size, _field_names, _is_ready, etc., as these remain unchanged after reorder
516518
@@ -580,7 +582,7 @@ def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) ->
580582 Otherwise, set to NOT_PRODUCED. Default is True.
581583
582584 Returns:
583- all_fields (list[dict[FieldMeta]]): A list of dictionaries containing field metadata.
585+ all_fields (list[dict[str, FieldMeta]]): A list of dictionaries containing field metadata.
584586 """
585587 all_fields = []
586588 batch_size = tensor_dict .batch_size [0 ]
0 commit comments