Skip to content

Commit df450a7

Browse files
committed
Fix mistakes when resolving merge conflict
1 parent b2541f4 commit df450a7

3 files changed

Lines changed: 26 additions & 27 deletions

File tree

src/torchjd/aggregation/_gradvac.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,9 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
137137
self._state_key = key
138138

139139

140-
class GradVac(GramianWeightedAggregator, Stateful):
140+
class GradVac(GramianWeightedAggregator, Stochastic):
141141
r"""
142-
:class:`~torchjd.aggregation._mixins.Stateful`
142+
:class:`~torchjd.aggregation._mixins.Stochastic`
143143
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
144144
Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
145145
Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
@@ -159,22 +159,14 @@ class GradVac(GramianWeightedAggregator, Stateful):
159159
160160
:param beta: EMA decay for :math:`\hat{\phi}`.
161161
:param eps: Small non-negative constant added to denominators.
162-
163-
.. note::
164-
For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
165-
using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
166-
you need reproducibility.
167-
168-
.. note::
169-
To apply GradVac with the `whole_model`, `enc_dec`, `all_layer` or `all_matrix` grouping
170-
strategy, please refer to the :doc:`Grouping </examples/grouping>` examples.
162+
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
163+
the global PyTorch RNG to fork an independent stream.
171164
"""
172165

173-
gramian_weighting: GradVacWeighting
174-
175-
def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
176-
weighting = GradVacWeighting(beta=beta, eps=eps)
177-
super().__init__(weighting)
166+
def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None:
167+
weighting = GradVacWeighting(beta=beta, eps=eps, seed=seed)
168+
GramianWeightedAggregator.__init__(self, weighting)
169+
Stochastic.__init__(self, generator=weighting.generator)
178170
self._gradvac_weighting = weighting
179171
self.register_full_backward_pre_hook(raise_non_differentiable_error)
180172

@@ -195,8 +187,9 @@ def eps(self, value: float) -> None:
195187
self._gradvac_weighting.eps = value
196188

197189
def reset(self) -> None:
198-
"""Clears EMA state so the next forward starts from zero targets."""
190+
"""Resets the random number generator and clears the EMA state."""
199191

192+
Stochastic.reset(self)
200193
self._gradvac_weighting.reset()
201194

202195
def __repr__(self) -> str:

src/torchjd/aggregation/_pcgrad.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,19 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
5454
return weights.to(device)
5555

5656

57-
class PCGrad(GramianWeightedAggregator):
57+
class PCGrad(GramianWeightedAggregator, Stochastic):
5858
"""
5959
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of
6060
`Gradient Surgery for Multi-Task Learning <https://arxiv.org/pdf/2001.06782.pdf>`_.
61-
"""
6261
63-
gramian_weighting: PCGradWeighting
62+
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
63+
the global PyTorch RNG to fork an independent stream.
64+
"""
6465

65-
def __init__(self) -> None:
66-
super().__init__(PCGradWeighting())
66+
def __init__(self, seed: int | None = None) -> None:
67+
weighting = PCGradWeighting(seed=seed)
68+
GramianWeightedAggregator.__init__(self, weighting)
69+
Stochastic.__init__(self, generator=weighting.generator)
6770

6871
# This prevents running into a RuntimeError due to modifying stored tensors in place.
6972
self.register_full_backward_pre_hook(raise_non_differentiable_error)

src/torchjd/aggregation/_random.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,18 @@ def forward(self, matrix: Tensor, /) -> Tensor:
3030
return weights
3131

3232

33-
class Random(WeightedAggregator):
33+
class Random(WeightedAggregator, Stochastic):
3434
"""
3535
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of
3636
the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of
3737
Random Weighting: A Litmus Test for Multi-Task Learning
3838
<https://arxiv.org/pdf/2111.10603.pdf>`_.
39-
"""
4039
41-
weighting: RandomWeighting
40+
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
41+
the global PyTorch RNG to fork an independent stream.
42+
"""
4243

43-
def __init__(self) -> None:
44-
super().__init__(RandomWeighting())
44+
def __init__(self, seed: int | None = None) -> None:
45+
weighting = RandomWeighting(seed=seed)
46+
WeightedAggregator.__init__(self, weighting)
47+
Stochastic.__init__(self, generator=weighting.generator)

0 commit comments

Comments
 (0)