@@ -113,7 +113,9 @@ def __iter__(self) -> Iterator[int]:
113113 # subsample
114114 indices = indices [self .step + self .rank : self .total_size : self .world_size ]
115115
116- yield from iter (indices )
116+ for idx in indices :
117+ self ._consumed .record (1 )
118+ yield idx
117119 self .step = 0
118120
119121 def __len__ (self ) -> int :
@@ -133,9 +135,6 @@ def set_epoch(self, epoch: int) -> None:
133135 """
134136 self .epoch = epoch
135137
136- def record_consumed_samples (self , n : int ) -> None :
137- self ._consumed .record (n )
138-
139138 def get_total_consumed_steps (self ) -> int :
140139 return self ._consumed .total_for_checkpoint ()
141140
@@ -275,7 +274,9 @@ def __iter__(self) -> Iterator[int]:
275274 assert len (indices ) == self .total_size
276275 indices = indices [self .step + self .rank : self .total_size : self .world_size ]
277276 assert len (indices ) == self .num_samples - self .step // self .world_size
278- yield from iter (indices )
277+ for idx in indices :
278+ self ._consumed .record (1 )
279+ yield idx
279280 self .step = 0
280281
281282 def __len__ (self ) -> int :
@@ -294,9 +295,6 @@ def set_epoch(self, epoch: int) -> None:
294295 """
295296 self .epoch = epoch
296297
297- def record_consumed_samples (self , n : int ) -> None :
298- self ._consumed .record (n )
299-
300298 def get_total_consumed_steps (self ) -> int :
301299 return self ._consumed .total_for_checkpoint ()
302300
0 commit comments