Skip to content

Commit 1feca7c

Browse files
authored
Fix docs remaining check, handle rescale edge case
1 parent d54063e commit 1feca7c

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

fms_fsdp/utils/dataset_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,21 +1389,25 @@ def setup(self):
13891389
)
13901390
[d.setup() for d in self.data]
13911391
self.n_docs_remaining = [d._len for d in self.data]
1392+
assert (
1393+
sum(self.n_docs_remaining) > 0
1394+
), f"No documents detected in shard {self.rank} of {self.datapath}"
13921395

13931396
self.generator = torch.Generator().manual_seed(self.rank)
13941397

13951398
def __iter__(self):
13961399
self.setup()
13971400
# Grab one doc at a time in random order
13981401
data = [iter(d) for d in self.data]
1402+
# Reset if we're rescaling into a prematurely finished epoch
1403+
# (e.g. [1,1,0,0,0,0] into [1,1,0] [0,0,0])
1404+
if sum(self.n_docs_remaining) == 0:
1405+
self.n_docs_remaining = [d._len for d in self.data]
13991406
while True:
14001407
# Sample logical shard (or load from ckp)
14011408
if self.current_reader is not None:
14021409
ind = self.current_reader
14031410
else:
1404-
assert (
1405-
sum(self.n_docs_remaining) > 0
1406-
), f"No documents detected in {self.datapath}"
14071411
ind = torch.multinomial(
14081412
torch.tensor(self.n_docs_remaining, dtype=torch.float),
14091413
1,

0 commit comments

Comments
 (0)