Skip to content

Commit 0748007

Browse files
authored
Add titanshard and 256k docbreak
1 parent b9fa538 commit 0748007

1 file changed

Lines changed: 60 additions & 39 deletions

File tree

fms_fsdp/utils/dataset_utils.py

Lines changed: 60 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,7 @@ def __init__(
875875
seed: int = 42,
876876
min_length: int = 1,
877877
max_chunksize: int = 1024,
878+
max_consecutive_chunks: int = 256,
878879
verbose: bool = False,
879880
):
880881
super().__init__(datapath, rank, worldsize)
@@ -887,6 +888,7 @@ def __init__(
887888
self.eos = delimiter_token
888889
self.bos = bos_token
889890
self.drop = strip_tokens
891+
self.max_consec = max_consecutive_chunks
890892
self.verbose = verbose
891893
self.docset: List[
892894
Any
@@ -902,6 +904,7 @@ def __init__(
902904
self.tokens_seen = 0
903905
self.docs_seen = 0
904906
self.percent_seen = 0
907+
self.consec = 0
905908

906909
self.state_params = [
907910
"dataset",
@@ -912,6 +915,7 @@ def __init__(
912915
"docs_seen",
913916
"percent_seen",
914917
"lcg_state",
918+
"consec",
915919
]
916920

917921
# Setup flags
@@ -942,75 +946,89 @@ def setup(self):
942946
for root, dirs, files in os.walk(datapath, topdown=False, followlinks=True)
943947
for name in files
944948
if self.filehandler.is_legal(os.path.join(root, name))
949+
and os.path.getsize(os.path.join(root, name)) > 1_000_000
950+
# Require files larger than 1 MB (1,000,000 bytes) so empty or near-empty files are skipped
945951
]
946952
shards.sort() # Ensure consistent sharding across machines
947-
start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize
948-
end_frag = (
949-
(self.rank + 1) * self.worldsize * len(shards)
950-
) // self.worldsize
951-
shardfrags = [
952-
(shards[i // self.worldsize], i % self.worldsize)
953-
for i in range(start_frag, end_frag)
954-
]
955-
956-
# Assemble length of each owned shard file
957953

954+
# Find metadata file
958955
countfiles = []
959956
if os.path.exists(os.path.join(pardir, "meta")):
960957
countfiles = [
961958
x
962959
for x in os.listdir(os.path.join(pardir, "meta"))
963960
if "counts" in x and "csv" in x
964961
]
965-
doc_counts = {}
966962
if len(countfiles) > 0:
967963
# Count file exists, use it
968964
countpath = os.path.join(pardir, "meta", countfiles[0])
965+
else:
966+
countpath = ""
967+
968+
# Use shard file sizes to perform partitioning
969+
# Create shardset of form shardid -> [start fraction, end fraction], each clamped to [0, 1]
970+
if len(countfiles) > 0:
971+
sizes = {}
969972
with open(countpath, "r") as csvfile:
970973
reader = csv.DictReader(csvfile)
971974
for row in reader:
972975
fullpath = row["dataset/filename"]
973-
prefix = fullpath.find("/" + dataset) + 1
976+
prefix = fullpath.find(dataset + "/")
974977
if prefix > 0:
978+
key = fullpath[prefix + len(dataset) + 1 :]
979+
sizes[key] = int(row["size"])
980+
shard_sizes = [sizes[shard] for shard in shards]
981+
else:
982+
# Count file does not exist, stat every shard file for its size in bytes
983+
shard_sizes = [
984+
os.path.getsize(os.path.join(datapath, shard)) for shard in shards
985+
]
986+
shard_sizes = [s / sum(shard_sizes) for s in shard_sizes]
987+
start = self.rank / self.worldsize
988+
end = (self.rank + 1) / self.worldsize
989+
shardset = {}
990+
tally = 0
991+
for i in range(len(shards)):
992+
if tally <= end and tally + shard_sizes[i] >= start:
993+
shardset[shards[i]] = [
994+
min(max((start - tally) / shard_sizes[i], 0), 1),
995+
min(max((end - tally) / shard_sizes[i], 0), 1),
996+
]
997+
tally += shard_sizes[i]
998+
999+
# Assemble length of each owned shard file
1000+
doc_counts = {}
1001+
if len(countfiles) > 0:
1002+
# Count file exists, use it
1003+
with open(countpath, "r") as csvfile:
1004+
reader = csv.DictReader(csvfile)
1005+
for row in reader:
1006+
fullpath = row["dataset/filename"]
1007+
prefix = fullpath.find(dataset + "/")
1008+
if prefix >= 0:
9751009
key = fullpath[prefix + len(dataset) + 1 :]
9761010
doc_counts[key] = int(row["documents"])
9771011
else:
9781012
# Count file does not exist, touch every owned file for length
979-
unique_shardfiles = set(shard for shard, frag in shardfrags)
9801013
doc_counts = {
9811014
shard: self.filehandler.length(os.path.join(datapath, shard))
982-
for shard in unique_shardfiles
1015+
for shard in shardset
9831016
}
9841017

985-
# Read shardfrags, assemble doc list for each file shard (aggregating over fragments):
986-
ndocs = -1
987-
docset = {} # shardid -> (min docid, max docid)
988-
for i, (shard, frag) in enumerate(shardfrags):
989-
ndocs = doc_counts[shard]
990-
doc_start = (ndocs * frag) // self.worldsize
991-
doc_end = (
992-
ndocs * frag + ndocs
993-
) // self.worldsize - 1 # Inclusive upper bound
994-
if shard not in docset:
995-
docset[shard] = [doc_start, doc_end]
996-
min_d, max_d = docset[shard]
997-
if doc_start < min_d:
998-
docset[shard][0] = doc_start
999-
if doc_end > max_d:
1000-
docset[shard][1] = doc_end
1001-
1002-
# Add shard entries to self.docset
1018+
# Assemble doc list for each file shard
1019+
# Create docset of form [shardid, min docid, max docid]
10031020
doccount = 0
1004-
for shardid in docset:
1005-
min_d = docset[shardid][0]
1006-
max_d = docset[shardid][1]
1007-
self.docset.append((shardid, min_d, max_d))
1008-
doccount += max_d - min_d + 1
1021+
for shard in shardset:
1022+
ndocs = doc_counts[shard]
1023+
doc_start = int(ndocs * shardset[shard][0])
1024+
doc_end = max(doc_start, int(ndocs * shardset[shard][1]) - 1) # inclusive upper bound
1025+
self.docset.append([shard, doc_start, doc_end])
1026+
doccount += doc_end - doc_start + 1
10091027
self._len = doccount
10101028

10111029
if self.verbose:
10121030
logging.info(
1013-
f" Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}"
1031+
f" Worker {self.rank} ingested {len(self.docset)} shard fragments from {dataset}"
10141032
)
10151033

10161034
# Shuffle shard files - guaranteed inconsistent across workers
@@ -1065,8 +1083,11 @@ def _construct_chunk(self, j, doc, n_chunks):
10651083
# Add bos/eos tokens if needed
10661084
if self.bos is not None and j == 0:
10671085
chunk = [self.bos] + chunk
1068-
if j == n_chunks - 1:
1086+
if j == n_chunks - 1 or self.consec == self.max_consec:
10691087
chunk = chunk + [self.eos]
1088+
self.consec = 0
1089+
else:
1090+
self.consec += 1
10701091
return chunk
10711092

10721093
def _random_map_docid(self, size):

0 commit comments

Comments
 (0)