33"""
44
55from __future__ import annotations
6- from typing import Optional , Any , Dict
6+ from abc import abstractmethod , ABC
7+ import logging
8+ import numpy
79import os
810import socket
9- import logging
11+ from typing import Callable , Optional , Any , Dict , Type , Union
1012
1113import torch
1214from torch .nn .parallel import DistributedDataParallel
1315
14- from returnn .config import Config
15- from returnn .util .basic import CollectionReadCheckCovered
16+ from returnn .util .basic import CollectionReadCheckCovered , OptionalNotImplementedError
1617
1718_logger = logging .getLogger ("returnn.torch.distributed" )
1819
1920
21+ class ParamSynchronizer (ABC ):
22+ """
23+ Custom parameter synchronization primitive.
24+
25+ Contains a callback that is called after every train step to synchronize model parameters
26+ across processes/nodes.
27+ """
28+
29+ @abstractmethod
30+ def __init__ (self , * , rank : int , size : int , local_rank : int , local_size : int , ** kwargs ):
31+ """
32+ `__init__` called after the default global process group is created.
33+ Can be used to initialize any additional custom process (sub)groups.
34+
35+ Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
36+
37+ :param rank: global rank of the current process across all nodes
38+ :param size: global world size across all nodes
39+ :param local_rank: local rank of the current process on the current node
40+ :param local_rank: local world size on the current node
41+ :param kwargs: any additional kwargs
42+ """
43+ super ().__init__ ()
44+
45+ self .rank = rank
46+ self .size = size
47+ self .local_rank = local_rank
48+ self .local_size = local_size
49+
50+ def make_distributed_model (self , * , module : torch .nn .Module , ** kwargs ) -> DistributedDataParallel :
51+ """
52+ Creates an associated `DistributedDataParallel` for the given module for gradient synchronization.
53+
54+ This function can be left unimplemented if no gradient synchronization is done.
55+
56+ Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
57+ """
58+ raise OptionalNotImplementedError
59+
60+ @abstractmethod
61+ def step (self , * , module : torch .nn .Module , train_step_idx : int , ** kwargs ):
62+ """
63+ Parameter synchronization callback called after every train step with updated model parameters.
64+
65+ Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
66+
67+ :param module: the NN being trained
68+ :param train_step_idx: the current train step
69+ :param kwargs: any additional kwargs
70+ """
71+ raise NotImplementedError
72+
73+ def __call__ (self , * args , ** kwargs ):
74+ """forwards to :func:``step``"""
75+ return self .step (* args , ** kwargs )
76+
77+
2078class DistributedContext :
2179 """
2280 This class setups some helper functions for torch distributed training
@@ -26,6 +84,9 @@ def __init__(self, options: Dict[str, Any]):
2684 import torch .distributed as dist
2785
2886 self ._opts = CollectionReadCheckCovered (options )
87+ # Only used to generate forwards compatibility ensuring random kwargs, therefore
88+ # the seed is not important
89+ self ._rng = numpy .random .default_rng ()
2990
3091 # when no backend is specified, both gloo and nccl backends will be created
3192 # the gloo backend will be used for collectives with CPU tensors and
@@ -42,8 +103,13 @@ def __init__(self, options: Dict[str, Any]):
42103 % (socket .gethostname (), os .getpid (), self ._rank , self ._size , self ._local_rank , self ._local_size )
43104 )
44105
106+ self ._custom_sync_class : Optional [Union [Callable , Type [ParamSynchronizer ]]] = self ._opts .get (
107+ "synchronizer" , None
108+ )
109+ self ._custom_sync : Optional [Callable ] = None
45110 self ._reduce_type = self ._opts .get ("reduce_type" , "grad" )
46111 self ._param_sync_step : Optional [int ] = self ._opts .get ("param_sync_step" , None )
112+
47113 if self ._reduce_type == "param" :
48114 assert isinstance (self ._param_sync_step , int ) and self ._param_sync_step > 0 , (
49115 f"reduce_type param: param_sync_step must be a positive int,"
@@ -52,6 +118,23 @@ def __init__(self, options: Dict[str, Any]):
52118 _logger .info (f"reduce_type param: param_sync_step { self ._param_sync_step } " )
53119 elif self ._reduce_type == "grad" :
54120 _logger .info ("reduce_type grad" )
121+ elif self ._reduce_type == "custom" :
122+ if issubclass (self ._custom_sync_class , ParamSynchronizer ):
123+ self ._custom_sync = self ._custom_sync_class (
124+ rank = self ._rank ,
125+ size = self ._size ,
126+ local_rank = self ._local_rank ,
127+ local_size = self ._local_size ,
128+ ** {f"fwd_compatible_random_kwarg_{ self ._rng .integers (0 , 100 )} " : None },
129+ )
130+ elif isinstance (self ._custom_sync_class , Callable ):
131+ self ._custom_sync = self ._custom_sync_class
132+ else :
133+ raise ValueError (
134+ f"synchronizer must either be a callable or a class inheriting from { ParamSynchronizer .__name__ } "
135+ )
136+
137+ _logger .info (f"reduce_type custom: { type (self ._custom_sync )} " )
55138 else :
56139 raise ValueError (f"invalid reduce_type { self ._reduce_type !r} " )
57140
@@ -70,6 +153,8 @@ def _check_no_unknown_opts(self):
70153 self ._opts .get ("options" )
71154 if self ._reduce_type == "param" :
72155 self ._opts .get ("sync_on_cpu" )
156+ if self ._reduce_type == "custom" :
157+ self ._opts .get ("synchronizer" )
73158
74159 self ._opts .assert_all_read ()
75160
@@ -102,7 +187,24 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
102187 """
103188 if self ._reduce_type == "param" :
104189 return None
105- assert self ._reduce_type == "grad"
190+ assert self ._reduce_type in ["custom" , "grad" ]
191+
192+ if self ._reduce_type == "custom" :
193+ assert isinstance (self ._custom_sync , (ParamSynchronizer , Callable ))
194+
195+ if isinstance (self ._custom_sync , ParamSynchronizer ):
196+ try :
197+ return self ._custom_sync .make_distributed_model (
198+ module = module , ** {f"fwd_compatible_random_kwarg_{ self ._rng .integers (0 , 100 )} " : None }
199+ )
200+ except OptionalNotImplementedError :
201+ pass
202+ else :
203+ # callable short form does not have support for DistributedDataParallel
204+ pass
205+
206+ return None
207+
106208 cls = self ._opts .get ("class" , DistributedDataParallel )
107209 if cls is not DistributedDataParallel :
108210 _logger .warning (f"Using custom class { cls } instead of DistributedDataParallel, might be unsupported." )
@@ -115,7 +217,14 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
115217
116218 def step_after_param_update (self , * , module : torch .nn .Module , epoch_step_idx : int ):
117219 """one train step"""
118- if self ._reduce_type == "param" and ((epoch_step_idx % self ._param_sync_step ) == (self ._param_sync_step - 1 )):
220+ if self ._reduce_type == "custom" :
221+ with torch .no_grad (): # TODO: do we want this for all syncers?
222+ self ._custom_sync (
223+ module = module ,
224+ train_step_idx = epoch_step_idx ,
225+ ** {f"fwd_compatible_random_kwarg_{ self ._rng .integers (0 , 100 )} " : None },
226+ )
227+ elif self ._reduce_type == "param" and ((epoch_step_idx % self ._param_sync_step ) == (self ._param_sync_step - 1 )):
119228 _sync_params_avg (module = module , sync_on_cpu = self ._opts .get ("sync_on_cpu" , False ))
120229
121230
@@ -127,7 +236,7 @@ def get_ctx(config=None) -> Optional[DistributedContext]:
127236 """
128237 :param Config|None config:
129238 :returns: the global context if Torch distributed is enabled, or None otherwise.
130- If we did not setup the context yet, it will automatically create it.
239+ If we did not set up the context yet, it will automatically create it.
131240 """
132241 global _is_set_up , _ctx
133242 if _is_set_up :
@@ -155,7 +264,7 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
155264
156265 if sync_on_cpu :
157266 for param in module .parameters ():
158- # Separately move each param to CPU (instead of the whole module), to safe CPU memory.
267+ # Separately move each param to CPU (instead of the whole module), to save CPU memory.
159268 param_cpu = param .to (torch .device ("cpu" ))
160269 # On CPU, we are likely using Gloo, and Gloo does not support AVG
161270 dist .all_reduce (param_cpu .data , op = dist .ReduceOp .SUM )
@@ -166,12 +275,11 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
166275 if dist .get_backend () == "gloo" :
167276 # Gloo does not support AVG
168277 reduce_op = dist .ReduceOp .SUM
278+ elif hasattr (dist .ReduceOp , "AVG" ):
279+ reduce_op = dist .ReduceOp .AVG
169280 else :
170- if hasattr (dist .ReduceOp , "AVG" ):
171- reduce_op = dist .ReduceOp .AVG
172- else :
173- # Older PyTorch versions do not have ReduceOp.AVG.
174- reduce_op = dist .ReduceOp .SUM
281+ # Older PyTorch versions do not have ReduceOp.AVG.
282+ reduce_op = dist .ReduceOp .SUM
175283
176284 for param in module .parameters ():
177285 dist .all_reduce (param .data , op = reduce_op )
0 commit comments