@@ -68,12 +68,6 @@ def set_or_remove(key, value):
6868 set_or_remove ("min_chunk_size" , config .opt_typed_value ("min_chunk_size" , 0 ) or None )
6969 set_or_remove ("chunking_variance" , config .float ("chunking_variance" , 0 ))
7070
71- dd_cfg = config .typed_value ("dataset_distribution" , "random_seed_offset" )
72- assert dd_cfg in ["random_seed_offset" , "shard" ]
73- shard_index , num_shards = Dataset ._get_rank_and_size (config ) if dd_cfg == "shard" else 0 , 1
74- set_or_remove ("num_shards" , num_shards )
75- set_or_remove ("shard_index" , shard_index )
76-
7771 @staticmethod
7872 def get_default_kwargs_eval (config : Config ) -> Dict [str , Any ]:
7973 """
@@ -118,8 +112,8 @@ def __init__(
118112 min_chunk_size = 0 ,
119113 chunking_variance = 0 ,
120114 estimated_num_seqs = None ,
121- num_shards : int = 1 ,
122- shard_index : int = 0 ,
115+ num_shards : Optional [ int ] = None ,
116+ shard_index : Optional [ int ] = None ,
123117 ):
124118 """
125119 :param str name: e.g. "train" or "eval"
@@ -178,9 +172,9 @@ def __init__(
178172 self ._chunking = chunking
179173 self .chunk_size , self .chunk_step , self .custom_chunking_func = self ._parse_chunking (chunking )
180174 self ._context_window = context_window
181- assert 0 <= shard_index < num_shards
182- self .num_shards = num_shards
183- self .shard_index = shard_index
175+ assert ( shard_index is None and num_shards is None ) or 0 <= shard_index < num_shards
176+ self ._num_shards = num_shards
177+ self ._shard_index = shard_index
184178 if isinstance (context_window , (tuple , list )):
185179 assert len (context_window ) == 2
186180 for elem in context_window :
@@ -219,7 +213,7 @@ def __repr__(self):
219213 getattr (self , "epoch" , "<unknown>" ),
220214 )
221215
222- _getnewargs_exclude_attrs = set () # type: typing.Set[str]
216+ _getnewargs_exclude_attrs = { "num_shards" , "shard_index" } # type: typing.Set[str]
223217 _getnewargs_remap = {} # type: typing.Dict[str,str]
224218
225219 @staticmethod
@@ -256,25 +250,51 @@ def __reduce__(self):
256250 state = {attr : getattr (self , attr ) for attr in ["epoch" , "zpad" ]}
257251 return Dataset ._create_from_reduce , (self .__class__ , kwargs , state )
258252
253+ @property
254+ def num_shards (self ) -> int :
255+ if self ._num_shards is None :
256+ self ._shard_index , self ._num_shards = self ._get_sharding_info ()
257+ return self ._num_shards
258+
259+ @property
260+ def shard_index (self ) -> int :
261+ if self ._shard_index is None :
262+ self ._shard_index , self ._num_shards = self ._get_sharding_info ()
263+ return self ._shard_index
264+
259265 @staticmethod
260- def _get_rank_and_size (config : Config ) -> Tuple [int , int ]:
266+ def _get_sharding_info (config : Optional [ Config ] = None ) -> Tuple [int , int ]:
261267 """
268+ :param config: current RETURNN config, if not set, will fetch global
262269 :return: tuple (rank, size): the global rank and size for distributed trainings
263270 """
264- if config . typed_value ( "torch_distributed" ) is not None :
265- import returnn .torch . distributed
271+ if config is None :
272+ from returnn .config import get_global_config
266273
267- ctx = returnn . torch . distributed . get_ctx ( config = config )
268- return ctx . rank (), ctx . size ()
269- elif config .is_true ("use_horovod" ):
274+ config = get_global_config ( return_empty_if_none = True )
275+
276+ if config .is_true ("use_horovod" ):
270277 assert config .bool ("use_tensorflow" , False ) or config .value ("backend" , "" ).startswith ("tensorflow" )
271278
272279 import returnn .tf .horovod
273280
274281 ctx = returnn .tf .horovod .get_ctx (config = config )
275282 return ctx .rank (), ctx .size ()
276- else :
277- return 0 , 1
283+
284+ rank , size = 0 , 1
285+ if config .typed_value ("torch_distributed" ) is not None :
286+ import returnn .torch .distributed
287+
288+ ctx = returnn .torch .distributed .get_ctx (config = config )
289+ rank , size = ctx .rank (), ctx .size ()
290+ if config .typed_value ("torch_dataloader_opts" ) is not None :
291+ import torch .utils .data
292+
293+ worker_info = torch .utils .data .get_worker_info ()
294+ if worker_info is not None :
295+ size *= worker_info .num_workers
296+ rank += worker_info .id
297+ return rank , size
278298
279299 @property
280300 def random_seed_offset (self ) -> int :
0 commit comments