@@ -137,9 +137,9 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
137137 self ._state_key = key
138138
139139
140- class GradVac (GramianWeightedAggregator , Stateful ):
140+ class GradVac (GramianWeightedAggregator , Stochastic ):
141141 r"""
142- :class:`~torchjd.aggregation._mixins.Stateful `
142+ :class:`~torchjd.aggregation._mixins.Stochastic `
143143 :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
144144 Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
145145 Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
@@ -159,22 +159,14 @@ class GradVac(GramianWeightedAggregator, Stateful):
159159
160160 :param beta: EMA decay for :math:`\hat{\phi}`.
161161 :param eps: Small non-negative constant added to denominators.
162-
163- .. note::
164- For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
165- using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
166- you need reproducibility.
167-
168- .. note::
169- To apply GradVac with the `whole_model`, `enc_dec`, `all_layer` or `all_matrix` grouping
170- strategy, please refer to the :doc:`Grouping </examples/grouping>` examples.
162+ :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
163+ the global PyTorch RNG to fork an independent stream.
171164 """
172165
173- gramian_weighting : GradVacWeighting
174-
175- def __init__ (self , beta : float = 0.5 , eps : float = 1e-8 ) -> None :
176- weighting = GradVacWeighting (beta = beta , eps = eps )
177- super ().__init__ (weighting )
166+ def __init__ (self , beta : float = 0.5 , eps : float = 1e-8 , seed : int | None = None ) -> None :
167+ weighting = GradVacWeighting (beta = beta , eps = eps , seed = seed )
168+ GramianWeightedAggregator .__init__ (self , weighting )
169+ Stochastic .__init__ (self , generator = weighting .generator )
178170 self ._gradvac_weighting = weighting
179171 self .register_full_backward_pre_hook (raise_non_differentiable_error )
180172
@@ -195,8 +187,9 @@ def eps(self, value: float) -> None:
195187 self ._gradvac_weighting .eps = value
196188
197189 def reset (self ) -> None :
198- """Clears EMA state so the next forward starts from zero targets ."""
190+ """Resets the random number generator and clears the EMA state ."""
199191
192+ Stochastic .reset (self )
200193 self ._gradvac_weighting .reset ()
201194
202195 def __repr__ (self ) -> str :
0 commit comments