Skip to content

Commit 7d753f5

Browse files
committed
take torch num_workers into account for sharding
1 parent 653969c commit 7d753f5

2 files changed

Lines changed: 54 additions & 20 deletions

File tree

returnn/datasets/basic.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,6 @@ def set_or_remove(key, value):
6868
set_or_remove("min_chunk_size", config.opt_typed_value("min_chunk_size", 0) or None)
6969
set_or_remove("chunking_variance", config.float("chunking_variance", 0))
7070

71-
dd_cfg = config.typed_value("dataset_distribution", "random_seed_offset")
72-
assert dd_cfg in ["random_seed_offset", "shard"]
73-
shard_index, num_shards = Dataset._get_rank_and_size(config) if dd_cfg == "shard" else 0, 1
74-
set_or_remove("num_shards", num_shards)
75-
set_or_remove("shard_index", shard_index)
76-
7771
@staticmethod
7872
def get_default_kwargs_eval(config: Config) -> Dict[str, Any]:
7973
"""
@@ -118,8 +112,8 @@ def __init__(
118112
min_chunk_size=0,
119113
chunking_variance=0,
120114
estimated_num_seqs=None,
121-
num_shards: int = 1,
122-
shard_index: int = 0,
115+
num_shards: Optional[int] = None,
116+
shard_index: Optional[int] = None,
123117
):
124118
"""
125119
:param str name: e.g. "train" or "eval"
@@ -178,9 +172,9 @@ def __init__(
178172
self._chunking = chunking
179173
self.chunk_size, self.chunk_step, self.custom_chunking_func = self._parse_chunking(chunking)
180174
self._context_window = context_window
181-
assert 0 <= shard_index < num_shards
182-
self.num_shards = num_shards
183-
self.shard_index = shard_index
175+
assert (shard_index is None and num_shards is None) or 0 <= shard_index < num_shards
176+
self._num_shards = num_shards
177+
self._shard_index = shard_index
184178
if isinstance(context_window, (tuple, list)):
185179
assert len(context_window) == 2
186180
for elem in context_window:
@@ -219,7 +213,7 @@ def __repr__(self):
219213
getattr(self, "epoch", "<unknown>"),
220214
)
221215

222-
_getnewargs_exclude_attrs = set() # type: typing.Set[str]
216+
_getnewargs_exclude_attrs = {"num_shards", "shard_index"} # type: typing.Set[str]
223217
_getnewargs_remap = {} # type: typing.Dict[str,str]
224218

225219
@staticmethod
@@ -256,25 +250,51 @@ def __reduce__(self):
256250
state = {attr: getattr(self, attr) for attr in ["epoch", "zpad"]}
257251
return Dataset._create_from_reduce, (self.__class__, kwargs, state)
258252

253+
@property
def num_shards(self) -> int:
    """
    :return: total number of shards.
        If no value has been stored yet, shard index and number of shards are both
        fetched once via ``_get_sharding_info()`` and cached on the instance.
    """
    if self._num_shards is not None:
        return self._num_shards
    # Not set yet: resolve lazily and cache both values together.
    index, count = self._get_sharding_info()
    self._shard_index = index
    self._num_shards = count
    return self._num_shards
258+
259+
@property
def shard_index(self) -> int:
    """
    :return: index of the shard this dataset instance reads.
        If no value has been stored yet, shard index and number of shards are both
        fetched once via ``_get_sharding_info()`` and cached on the instance.
    """
    if self._shard_index is not None:
        return self._shard_index
    # Not set yet: resolve lazily and cache both values together.
    index, count = self._get_sharding_info()
    self._shard_index = index
    self._num_shards = count
    return self._shard_index
264+
259265
@staticmethod
def _get_sharding_info(config: Optional[Config] = None) -> Tuple[int, int]:
    """
    Determines which shard of the data the current process should read.

    With Horovod, the Horovod rank and size are returned as-is.
    Otherwise, the PyTorch distributed rank/size (if ``torch_distributed`` is configured)
    is combined with the PyTorch DataLoader worker id/count
    (if ``torch_dataloader_opts`` is configured and this runs inside a DataLoader worker).

    :param config: current RETURNN config, if not set, will fetch global
    :return: tuple (rank, size): the global shard index and the total number of shards
    """
    if config is None:
        from returnn.config import get_global_config

        config = get_global_config(return_empty_if_none=True)

    if config.is_true("use_horovod"):
        assert config.bool("use_tensorflow", False) or config.value("backend", "").startswith("tensorflow")

        import returnn.tf.horovod

        ctx = returnn.tf.horovod.get_ctx(config=config)
        return ctx.rank(), ctx.size()

    rank, size = 0, 1
    if config.typed_value("torch_distributed") is not None:
        import returnn.torch.distributed

        ctx = returnn.torch.distributed.get_ctx(config=config)
        rank, size = ctx.rank(), ctx.size()
    if config.typed_value("torch_dataloader_opts") is not None:
        import torch.utils.data

        # None when not running inside a DataLoader worker process.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Every distributed rank owns a contiguous range of `num_workers` shard indices.
            # Just adding the worker id to the rank would make e.g. (rank 0, worker 1)
            # and (rank 1, worker 0) collide on the same shard and leave other shards unread.
            rank = rank * worker_info.num_workers + worker_info.id
            size *= worker_info.num_workers
    return rank, size
278298

279299
@property
280300
def random_seed_offset(self) -> int:

tests/test_torch_dataset.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,20 @@ def test_MultiProcDataset_HDFDataset():
202202
assert c == n
203203

204204

205+
def test_dataset_num_workers_sharding():
    """
    Creates two explicitly sharded Task12AXDataset instances under a config
    with ``torch_dataloader_opts`` using two workers, and checks the resulting sharding.
    """
    config = Config({"backend": "torch", "torch_dataloader_opts": {"num_workers": 2}})
    with global_config_ctx(config):
        datasets = []
        for i in range(2):
            opts = {"class": "Task12AXDataset", "num_seqs": 100, "num_shards": 2, "shard_index": i}
            datasets.append(init_dataset(opts))
        for dataset in datasets:
            assert isinstance(dataset, Task12AXDataset)
            dataset.init_seq_order(epoch=1)
            # NOTE(review): num_shards is passed explicitly as 2 above while 4 is expected here,
            # presumably 2 explicit shards x 2 workers. Confirm how the explicit value
            # is meant to be combined with num_workers.
            assert dataset.shard_index < dataset.num_shards == 4
            assert dataset.num_seqs == 25
217+
218+
205219
if __name__ == "__main__":
206220
better_exchook.install()
207221
if len(sys.argv) <= 1:

0 commit comments

Comments
 (0)