@@ -875,6 +875,7 @@ def __init__(
         seed: int = 42,
         min_length: int = 1,
         max_chunksize: int = 1024,
+        max_consecutive_chunks: int = 256,
         verbose: bool = False,
     ):
         super().__init__(datapath, rank, worldsize)
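This hunk adds a `max_consecutive_chunks` knob to the constructor. A minimal usage sketch follows; the class name `StreamingDocDataset` and the arguments outside this hunk are assumptions, since the diff only shows the tail of the signature:

```python
# Hypothetical usage sketch: class name and omitted arguments are assumed,
# as this hunk shows only the tail of the constructor signature.
dataset = StreamingDocDataset(
    "/path/to/data",             # datapath (assumed positional)
    rank=0,
    worldsize=8,
    delimiter_token=0,           # token id used as EOS (assumed)
    max_chunksize=1024,
    max_consecutive_chunks=256,  # new: force an EOS after 256 back-to-back chunks
)
```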
@@ -887,6 +888,7 @@ def __init__(
         self.eos = delimiter_token
         self.bos = bos_token
         self.drop = strip_tokens
+        self.max_consec = max_consecutive_chunks
         self.verbose = verbose
         self.docset: List[
             Any
@@ -902,6 +904,7 @@ def __init__(
         self.tokens_seen = 0
         self.docs_seen = 0
         self.percent_seen = 0
+        self.consec = 0
 
         self.state_params = [
             "dataset",
@@ -912,6 +915,7 @@ def __init__(
             "docs_seen",
             "percent_seen",
             "lcg_state",
+            "consec",
         ]
 
         # Setup flags
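Registering `"consec"` in `state_params` means the counter is checkpointed along with the rest of the iterator state, so the forced-EOS cadence survives a save/resume. The base class is not shown in this diff; a plausible sketch of the mechanism it presumably implements:

```python
# Presumed base-class behavior (not shown in this diff): every attribute
# named in self.state_params is round-tripped through the state dict.
def state_dict(self):
    return {k: getattr(self, k) for k in self.state_params}

def load_state_dict(self, state):
    for k in self.state_params:
        setattr(self, k, state[k])
```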
@@ -942,75 +946,89 @@ def setup(self):
             for root, dirs, files in os.walk(datapath, topdown=False, followlinks=True)
             for name in files
             if self.filehandler.is_legal(os.path.join(root, name))
+            and os.path.getsize(os.path.join(root, name)) > 1_000_000
+            # 1mb minimum file size to prevent empty files
         ]
         shards.sort()  # Ensure consistent sharding across machines
-        start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize
-        end_frag = (
-            (self.rank + 1) * self.worldsize * len(shards)
-        ) // self.worldsize
-        shardfrags = [
-            (shards[i // self.worldsize], i % self.worldsize)
-            for i in range(start_frag, end_frag)
-        ]
-
-        # Assemble length of each owned shard file
 
+        # Find metadata file
         countfiles = []
         if os.path.exists(os.path.join(pardir, "meta")):
             countfiles = [
                 x
                 for x in os.listdir(os.path.join(pardir, "meta"))
                 if "counts" in x and "csv" in x
             ]
-        doc_counts = {}
         if len(countfiles) > 0:
             # Count file exists, use it
             countpath = os.path.join(pardir, "meta", countfiles[0])
+        else:
+            countpath = ""
+
+        # Use shard file sizes to perform partitioning
+        # Create shardlist of form shardid -> [start%, end%]
+        if len(countfiles) > 0:
+            sizes = {}
             with open(countpath, "r") as csvfile:
                 reader = csv.DictReader(csvfile)
                 for row in reader:
                     fullpath = row["dataset/filename"]
-                    prefix = fullpath.find("/" + dataset) + 1
+                    prefix = fullpath.find(dataset + "/")
                     if prefix > 0:
+                        key = fullpath[prefix + len(dataset) + 1 :]
+                        sizes[key] = int(row["size"])
+            shard_sizes = [sizes[shard] for shard in shards]
+        else:
+            # Count file does not exist, touch every owned file for length
+            shard_sizes = [
+                os.path.getsize(os.path.join(datapath, shard)) for shard in shards
+            ]
+        shard_sizes = [s / sum(shard_sizes) for s in shard_sizes]
+        start = self.rank / self.worldsize
+        end = (self.rank + 1) / self.worldsize
+        shardset = {}
+        tally = 0
+        for i in range(len(shards)):
+            if tally <= end and tally + shard_sizes[i] >= start:
+                shardset[shards[i]] = [
+                    min(max((start - tally) / shard_sizes[i], 0), 1),
+                    min(max((end - tally) / shard_sizes[i], 0), 1),
+                ]
+            tally += shard_sizes[i]
+
+        # Assemble length of each owned shard file
+        doc_counts = {}
+        if len(countfiles) > 0:
+            # Count file exists, use it
+            with open(countpath, "r") as csvfile:
+                reader = csv.DictReader(csvfile)
+                for row in reader:
+                    fullpath = row["dataset/filename"]
+                    prefix = fullpath.find(dataset + "/")
+                    if prefix >= 0:
                         key = fullpath[prefix + len(dataset) + 1 :]
                         doc_counts[key] = int(row["documents"])
         else:
             # Count file does not exist, touch every owned file for length
-            unique_shardfiles = set(shard for shard, frag in shardfrags)
             doc_counts = {
                 shard: self.filehandler.length(os.path.join(datapath, shard))
-                for shard in unique_shardfiles
+                for shard in shardset
             }
 
-        # Read shardfrags, assemble doc list for each file shard (aggregating over fragments):
-        ndocs = -1
-        docset = {}  # shardid -> (min docid, max docid)
-        for i, (shard, frag) in enumerate(shardfrags):
-            ndocs = doc_counts[shard]
-            doc_start = (ndocs * frag) // self.worldsize
-            doc_end = (
-                ndocs * frag + ndocs
-            ) // self.worldsize - 1  # Inclusive upper bound
-            if shard not in docset:
-                docset[shard] = [doc_start, doc_end]
-            min_d, max_d = docset[shard]
-            if doc_start < min_d:
-                docset[shard][0] = doc_start
-            if doc_end > max_d:
-                docset[shard][1] = doc_end
-
-        # Add shard entries to self.docset
+        # Assemble doc list for each file shard
+        # Create docset of form [shardid, min docid, max docid]
         doccount = 0
-        for shardid in docset:
-            min_d = docset[shardid][0]
-            max_d = docset[shardid][1]
-            self.docset.append((shardid, min_d, max_d))
-            doccount += max_d - min_d + 1
+        for shard in shardset:
+            ndocs = doc_counts[shard]
+            doc_start = int(ndocs * shardset[shard][0])
+            doc_end = max(doc_start, int(ndocs * shardset[shard][1]) - 1)  # inclusive upper bound
+            self.docset.append([shard, doc_start, doc_end])
+            doccount += doc_end - doc_start + 1
         self._len = doccount
 
         if self.verbose:
             logging.info(
-                f" Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}"
+                f" Worker {self.rank} ingested {len(self.docset)} shard fragments from {dataset}"
             )
 
         # Shuffle shard files - guaranteed inconsistent across workers
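This hunk replaces the old per-fragment index arithmetic with size-weighted fractional partitioning: each rank owns the slice `[rank/worldsize, (rank+1)/worldsize)` of the total byte mass, and every shard overlapping that slice records which fraction of itself the rank owns. A self-contained sketch of the interval-overlap loop above, with toy sizes standing in for real file sizes or CSV `size` entries:

```python
# Standalone sketch of the size-weighted partitioning loop above.
def partition(shards, sizes, rank, worldsize):
    fracs = [s / sum(sizes) for s in sizes]
    start, end = rank / worldsize, (rank + 1) / worldsize
    shardset, tally = {}, 0.0
    for shard, frac in zip(shards, fracs):
        # Keep any shard whose span [tally, tally + frac] overlaps [start, end],
        # recording the overlap as [start%, end%] within that shard.
        if tally <= end and tally + frac >= start:
            shardset[shard] = [
                min(max((start - tally) / frac, 0), 1),
                min(max((end - tally) / frac, 0), 1),
            ]
        tally += frac
    return shardset

print(partition(["a.bin", "b.bin"], [300, 100], rank=1, worldsize=2))
# {'a.bin': [0.666..., 1.0], 'b.bin': [0.0, 1.0]}: rank 1 owns the last
# third of a.bin plus all of b.bin, i.e. the second half of the byte mass.
```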
@@ -1065,8 +1083,11 @@ def _construct_chunk(self, j, doc, n_chunks):
         # Add bos/eos tokens if needed
         if self.bos is not None and j == 0:
             chunk = [self.bos] + chunk
-        if j == n_chunks - 1:
+        if j == n_chunks - 1 or self.consec == self.max_consec:
             chunk = chunk + [self.eos]
+            self.consec = 0
+        else:
+            self.consec += 1
         return chunk
 
     def _random_map_docid(self, size):
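With the counter threaded through `_construct_chunk`, a long document no longer defers its only EOS to the final chunk: once `max_consec` chunks have streamed out back-to-back, an EOS is appended and the counter resets. A toy trace of just that branch (the bos handling is omitted; `eos=0` and `max_consec=3` are arbitrary stand-ins):

```python
# Toy trace of the forced-EOS branch above; bos handling omitted.
class ChunkTrace:
    def __init__(self, eos=0, max_consec=3):
        self.eos, self.max_consec, self.consec = eos, max_consec, 0

    def construct_chunk(self, j, chunk, n_chunks):
        if j == n_chunks - 1 or self.consec == self.max_consec:
            chunk = chunk + [self.eos]  # close the sequence with EOS
            self.consec = 0
        else:
            self.consec += 1
        return chunk

t = ChunkTrace()
for j in range(8):
    print(j, t.construct_chunk(j, [7, 7], n_chunks=8))
# EOS (0) is appended at j == 3 (counter hit max_consec) and at j == 7
# (final chunk of the document).
```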