Skip to content

Commit e852396

Browse files
committed
torch distributed: add support for user-specified parameter synchronization
1 parent eb0f22e commit e852396

1 file changed

Lines changed: 121 additions & 13 deletions

File tree

returnn/torch/distributed.py

Lines changed: 121 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,78 @@
33
"""
44

55
from __future__ import annotations
6-
from typing import Optional, Any, Dict
6+
from abc import abstractmethod, ABC
7+
import logging
8+
import numpy
79
import os
810
import socket
9-
import logging
11+
from typing import Callable, Optional, Any, Dict, Type, Union
1012

1113
import torch
1214
from torch.nn.parallel import DistributedDataParallel
1315

14-
from returnn.config import Config
15-
from returnn.util.basic import CollectionReadCheckCovered
16+
from returnn.util.basic import CollectionReadCheckCovered, OptionalNotImplementedError
1617

1718
_logger = logging.getLogger("returnn.torch.distributed")
1819

1920

21+
class ParamSynchronizer(ABC):
22+
"""
23+
Custom parameter synchronization primitive.
24+
25+
Contains a callback that is called after every train step to synchronize model parameters
26+
across processes/nodes.
27+
"""
28+
29+
@abstractmethod
30+
def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **kwargs):
31+
"""
32+
`__init__` called after the default global process group is created.
33+
Can be used to initialize any additional custom process (sub)groups.
34+
35+
Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
36+
37+
:param rank: global rank of the current process across all nodes
38+
:param size: global world size across all nodes
39+
:param local_rank: local rank of the current process on the current node
40+
:param local_rank: local world size on the current node
41+
:param kwargs: any additional kwargs
42+
"""
43+
super().__init__()
44+
45+
self.rank = rank
46+
self.size = size
47+
self.local_rank = local_rank
48+
self.local_size = local_size
49+
50+
def make_distributed_model(self, *, module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
51+
"""
52+
Creates an associated `DistributedDataParallel` for the given module for gradient synchronization.
53+
54+
This function can be left unimplemented if no gradient synchronization is done.
55+
56+
Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
57+
"""
58+
raise OptionalNotImplementedError
59+
60+
@abstractmethod
61+
def step(self, *, module: torch.nn.Module, train_step_idx: int, **kwargs):
62+
"""
63+
Parameter synchronization callback called after every train step with updated model parameters.
64+
65+
Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
66+
67+
:param module: the NN being trained
68+
:param train_step_idx: the current train step
69+
:param kwargs: any additional kwargs
70+
"""
71+
raise NotImplementedError
72+
73+
def __call__(self, *args, **kwargs):
74+
"""forwards to :func:``step``"""
75+
return self.step(*args, **kwargs)
76+
77+
2078
class DistributedContext:
2179
"""
2280
This class setups some helper functions for torch distributed training
@@ -26,6 +84,9 @@ def __init__(self, options: Dict[str, Any]):
2684
import torch.distributed as dist
2785

2886
self._opts = CollectionReadCheckCovered(options)
87+
# Only used to generate forwards compatibility ensuring random kwargs, therefore
88+
# the seed is not important
89+
self._rng = numpy.random.default_rng()
2990

3091
# when no backend is specified, both gloo and nccl backends will be created
3192
# the gloo backend will be used for collectives with CPU tensors and
@@ -42,8 +103,13 @@ def __init__(self, options: Dict[str, Any]):
42103
% (socket.gethostname(), os.getpid(), self._rank, self._size, self._local_rank, self._local_size)
43104
)
44105

106+
self._custom_sync_class: Optional[Union[Callable, Type[ParamSynchronizer]]] = self._opts.get(
107+
"synchronizer", None
108+
)
109+
self._custom_sync: Optional[Callable] = None
45110
self._reduce_type = self._opts.get("reduce_type", "grad")
46111
self._param_sync_step: Optional[int] = self._opts.get("param_sync_step", None)
112+
47113
if self._reduce_type == "param":
48114
assert isinstance(self._param_sync_step, int) and self._param_sync_step > 0, (
49115
f"reduce_type param: param_sync_step must be a positive int,"
@@ -52,6 +118,23 @@ def __init__(self, options: Dict[str, Any]):
52118
_logger.info(f"reduce_type param: param_sync_step {self._param_sync_step}")
53119
elif self._reduce_type == "grad":
54120
_logger.info("reduce_type grad")
121+
elif self._reduce_type == "custom":
122+
if issubclass(self._custom_sync_class, ParamSynchronizer):
123+
self._custom_sync = self._custom_sync_class(
124+
rank=self._rank,
125+
size=self._size,
126+
local_rank=self._local_rank,
127+
local_size=self._local_size,
128+
**{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None},
129+
)
130+
elif isinstance(self._custom_sync_class, Callable):
131+
self._custom_sync = self._custom_sync_class
132+
else:
133+
raise ValueError(
134+
f"synchronizer must either be a callable or a class inheriting from {ParamSynchronizer.__name__}"
135+
)
136+
137+
_logger.info(f"reduce_type custom: {type(self._custom_sync)}")
55138
else:
56139
raise ValueError(f"invalid reduce_type {self._reduce_type!r}")
57140

@@ -70,6 +153,8 @@ def _check_no_unknown_opts(self):
70153
self._opts.get("options")
71154
if self._reduce_type == "param":
72155
self._opts.get("sync_on_cpu")
156+
if self._reduce_type == "custom":
157+
self._opts.get("synchronizer")
73158

74159
self._opts.assert_all_read()
75160

@@ -102,7 +187,24 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
102187
"""
103188
if self._reduce_type == "param":
104189
return None
105-
assert self._reduce_type == "grad"
190+
assert self._reduce_type in ["custom", "grad"]
191+
192+
if self._reduce_type == "custom":
193+
assert isinstance(self._custom_sync, (ParamSynchronizer, Callable))
194+
195+
if isinstance(self._custom_sync, ParamSynchronizer):
196+
try:
197+
return self._custom_sync.make_distributed_model(
198+
module=module, **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None}
199+
)
200+
except OptionalNotImplementedError:
201+
pass
202+
else:
203+
# callable short form does not have support for DistributedDataParallel
204+
pass
205+
206+
return None
207+
106208
cls = self._opts.get("class", DistributedDataParallel)
107209
if cls is not DistributedDataParallel:
108210
_logger.warning(f"Using custom class {cls} instead of DistributedDataParallel, might be unsupported.")
@@ -115,7 +217,14 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
115217

116218
def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: int):
117219
"""one train step"""
118-
if self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
220+
if self._reduce_type == "custom":
221+
with torch.no_grad(): # TODO: do we want this for all syncers?
222+
self._custom_sync(
223+
module=module,
224+
train_step_idx=epoch_step_idx,
225+
**{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None},
226+
)
227+
elif self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
119228
_sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False))
120229

121230

@@ -127,7 +236,7 @@ def get_ctx(config=None) -> Optional[DistributedContext]:
127236
"""
128237
:param Config|None config:
129238
:returns: the global context if Torch distributed is enabled, or None otherwise.
130-
If we did not setup the context yet, it will automatically create it.
239+
If we did not set up the context yet, it will automatically create it.
131240
"""
132241
global _is_set_up, _ctx
133242
if _is_set_up:
@@ -155,7 +264,7 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
155264

156265
if sync_on_cpu:
157266
for param in module.parameters():
158-
# Separately move each param to CPU (instead of the whole module), to safe CPU memory.
267+
# Separately move each param to CPU (instead of the whole module), to save CPU memory.
159268
param_cpu = param.to(torch.device("cpu"))
160269
# On CPU, we are likely using Gloo, and Gloo does not support AVG
161270
dist.all_reduce(param_cpu.data, op=dist.ReduceOp.SUM)
@@ -166,12 +275,11 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
166275
if dist.get_backend() == "gloo":
167276
# Gloo does not support AVG
168277
reduce_op = dist.ReduceOp.SUM
278+
elif hasattr(dist.ReduceOp, "AVG"):
279+
reduce_op = dist.ReduceOp.AVG
169280
else:
170-
if hasattr(dist.ReduceOp, "AVG"):
171-
reduce_op = dist.ReduceOp.AVG
172-
else:
173-
# Older PyTorch versions do not have ReduceOp.AVG.
174-
reduce_op = dist.ReduceOp.SUM
281+
# Older PyTorch versions do not have ReduceOp.AVG.
282+
reduce_op = dist.ReduceOp.SUM
175283

176284
for param in module.parameters():
177285
dist.all_reduce(param.data, op=reduce_op)

0 commit comments

Comments
 (0)