Skip to content

Commit a6735c0

Browse files
committed
Make GradVac, GradDrop, PCGrad and Random Stochastic (broken).
1 parent 56ea02a commit a6735c0

File tree

5 files changed

+109
-32
lines changed

5 files changed

+109
-32
lines changed

src/torchjd/aggregation/_graddrop.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
from torchjd._linalg import Matrix
77

88
from ._aggregator_bases import Aggregator
9+
from ._mixins import Stochastic
910
from ._utils.non_differentiable import raise_non_differentiable_error
1011

1112

1213
def _identity(P: Tensor) -> Tensor:
1314
return P
1415

1516

16-
class GradDrop(Aggregator):
17+
class GradDrop(Aggregator, Stochastic):
1718
"""
1819
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that applies the gradient combination
1920
steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign:
@@ -24,16 +25,21 @@ class GradDrop(Aggregator):
2425
increasing. Defaults to identity.
2526
:param leak: The tensor of leak values, determining how much each row is allowed to leak
2627
through. Defaults to None, which means no leak.
28+
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
29+
the global PyTorch RNG to fork an independent stream.
2730
"""
2831

29-
def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None:
32+
def __init__(
33+
self, f: Callable = _identity, leak: Tensor | None = None, seed: int | None = None
34+
) -> None:
3035
if leak is not None and leak.dim() != 1:
3136
raise ValueError(
3237
"Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = "
3338
f"{leak.shape}`.",
3439
)
3540

36-
super().__init__()
41+
Aggregator.__init__(self)
42+
Stochastic.__init__(self, seed=seed)
3743
self.f = f
3844
self.leak = leak
3945

@@ -50,7 +56,7 @@ def forward(self, matrix: Matrix, /) -> Tensor:
5056

5157
P = 0.5 * (torch.ones_like(matrix[0]) + matrix.sum(dim=0) / matrix.abs().sum(dim=0))
5258
fP = self.f(P)
53-
U = torch.rand(P.shape, dtype=matrix.dtype, device=matrix.device)
59+
U = torch.rand(P.shape, dtype=matrix.dtype, device=matrix.device, generator=self.generator)
5460

5561
vector = torch.zeros_like(matrix[0])
5662
for i in range(len(matrix)):

src/torchjd/aggregation/_gradvac.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@
66
from torch import Tensor
77

88
from torchjd._linalg import PSDMatrix
9-
from torchjd.aggregation._mixins import Stateful
9+
from torchjd.aggregation._mixins import Stochastic
1010

1111
from ._aggregator_bases import GramianWeightedAggregator
1212
from ._utils.non_differentiable import raise_non_differentiable_error
1313
from ._weighting_bases import Weighting
1414

1515

16-
class GradVac(GramianWeightedAggregator, Stateful):
16+
class GradVac(GramianWeightedAggregator, Stochastic):
1717
r"""
18-
:class:`~torchjd.aggregation._mixins.Stateful`
18+
:class:`~torchjd.aggregation._mixins.Stochastic`
1919
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
2020
Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
2121
Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
@@ -35,16 +35,14 @@ class GradVac(GramianWeightedAggregator, Stateful):
3535
3636
:param beta: EMA decay for :math:`\hat{\phi}`.
3737
:param eps: Small non-negative constant added to denominators.
38-
39-
.. note::
40-
For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
41-
using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
42-
you need reproducibility.
38+
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
39+
the global PyTorch RNG to fork an independent stream.
4340
"""
4441

45-
def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
46-
weighting = GradVacWeighting(beta=beta, eps=eps)
47-
super().__init__(weighting)
42+
def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None:
43+
weighting = GradVacWeighting(beta=beta, eps=eps, seed=seed)
44+
GramianWeightedAggregator.__init__(self, weighting)
45+
Stochastic.__init__(self, generator=weighting.generator)
4846
self._gradvac_weighting = weighting
4947
self.register_full_backward_pre_hook(raise_non_differentiable_error)
5048

@@ -65,17 +63,18 @@ def eps(self, value: float) -> None:
6563
self._gradvac_weighting.eps = value
6664

6765
def reset(self) -> None:
68-
"""Clears EMA state so the next forward starts from zero targets."""
66+
"""Resets the random number generator and clears the EMA state."""
6967

68+
Stochastic.reset(self)
7069
self._gradvac_weighting.reset()
7170

7271
def __repr__(self) -> str:
7372
return f"GradVac(beta={self.beta!r}, eps={self.eps!r})"
7473

7574

76-
class GradVacWeighting(Weighting[PSDMatrix], Stateful):
75+
class GradVacWeighting(Weighting[PSDMatrix], Stochastic):
7776
r"""
78-
:class:`~torchjd.aggregation._mixins.Stateful`
77+
:class:`~torchjd.aggregation._mixins.Stochastic`
7978
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
8079
:class:`~torchjd.aggregation.GradVac`.
8180
@@ -97,10 +96,13 @@ class GradVacWeighting(Weighting[PSDMatrix], Stateful):
9796
9897
:param beta: EMA decay for :math:`\hat{\phi}`.
9998
:param eps: Small non-negative constant added to denominators.
99+
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
100+
the global PyTorch RNG to fork an independent stream.
100101
"""
101102

102-
def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
103-
super().__init__()
103+
def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None:
104+
Weighting.__init__(self)
105+
Stochastic.__init__(self, seed=seed)
104106
if not (0.0 <= beta <= 1.0):
105107
raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
106108
if eps < 0.0:
@@ -132,8 +134,9 @@ def eps(self, value: float) -> None:
132134
self._eps = value
133135

134136
def reset(self) -> None:
135-
"""Clears EMA state so the next forward starts from zero targets."""
137+
"""Resets the random number generator and clears the EMA state."""
136138

139+
Stochastic.reset(self)
137140
self._phi_t = None
138141
self._state_key = None
139142

@@ -161,7 +164,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
161164
cG = C[i] @ G
162165

163166
others = [j for j in range(m) if j != i]
164-
perm = torch.randperm(len(others))
167+
perm = torch.randperm(len(others), generator=self.generator)
165168
shuffled_js = [others[idx] for idx in perm.tolist()]
166169

167170
for j in shuffled_js:

src/torchjd/aggregation/_mixins.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from abc import ABC, abstractmethod
22

3+
import torch
4+
35

46
class Stateful(ABC):
57
r"""
@@ -18,3 +20,41 @@ class Stateful(ABC):
1820
@abstractmethod
1921
def reset(self) -> None:
2022
"""Resets the internal state :math:`s_0`."""
23+
24+
25+
class Stochastic(Stateful, ABC):
26+
r"""
27+
Stateful mixin that represents mappings that have inherent randomness.
28+
29+
Internally, a ``Stochastic`` mapping holds a :class:`torch.Generator` that serves as an
30+
independent random number stream. Implementing classes must pass this generator to all torch
31+
random functions via their ``generator`` argument, e.g.:
32+
33+
.. code-block:: python
34+
35+
torch.rand(n, generator=self.generator)
36+
torch.randn(n, generator=self.generator)
37+
torch.randperm(n, generator=self.generator)
38+
39+
:param seed: Seed for the internal :class:`torch.Generator`. If ``None``, a seed is drawn
40+
from the global PyTorch RNG to fork an independent stream.
41+
:param generator: An existing :class:`torch.Generator` to share, typically from a companion
42+
:class:`Stochastic` instance (e.g. a :class:`Weighting` sharing the generator of its
43+
:class:`Aggregator`). Mutually exclusive with ``seed``.
44+
"""
45+
46+
def __init__(self, seed: int | None = None, generator: torch.Generator | None = None) -> None:
47+
if generator is not None and seed is not None:
48+
raise ValueError("Parameters `seed` and `generator` are mutually exclusive.")
49+
if generator is not None:
50+
self.generator = generator
51+
else:
52+
self.generator = torch.Generator()
53+
if seed is None:
54+
seed = int(torch.randint(0, 2**62, size=(1,), dtype=torch.int64).item())
55+
self.generator.manual_seed(seed)
56+
self._initial_rng_state = self.generator.get_state()
57+
58+
def reset(self) -> None:
59+
"""Resets the random number generator to its initial state."""
60+
self.generator.set_state(self._initial_rng_state)

src/torchjd/aggregation/_pcgrad.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,42 @@
66
from torchjd._linalg import PSDMatrix
77

88
from ._aggregator_bases import GramianWeightedAggregator
9+
from ._mixins import Stochastic
910
from ._utils.non_differentiable import raise_non_differentiable_error
1011
from ._weighting_bases import Weighting
1112

1213

13-
class PCGrad(GramianWeightedAggregator):
14+
class PCGrad(GramianWeightedAggregator, Stochastic):
1415
"""
1516
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of
1617
`Gradient Surgery for Multi-Task Learning <https://arxiv.org/pdf/2001.06782.pdf>`_.
18+
19+
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
20+
the global PyTorch RNG to fork an independent stream.
1721
"""
1822

19-
def __init__(self) -> None:
20-
super().__init__(PCGradWeighting())
23+
def __init__(self, seed: int | None = None) -> None:
24+
weighting = PCGradWeighting(seed=seed)
25+
GramianWeightedAggregator.__init__(self, weighting)
26+
Stochastic.__init__(self, generator=weighting.generator)
2127

2228
# This prevents running into a RuntimeError due to modifying stored tensors in place.
2329
self.register_full_backward_pre_hook(raise_non_differentiable_error)
2430

2531

26-
class PCGradWeighting(Weighting[PSDMatrix]):
32+
class PCGradWeighting(Weighting[PSDMatrix], Stochastic):
2733
"""
2834
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
2935
:class:`~torchjd.aggregation.PCGrad`.
36+
37+
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
38+
the global PyTorch RNG to fork an independent stream.
3039
"""
3140

41+
def __init__(self, seed: int | None = None) -> None:
42+
Weighting.__init__(self)
43+
Stochastic.__init__(self, seed=seed)
44+
3245
def forward(self, gramian: PSDMatrix, /) -> Tensor:
3346
# Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
3447
device = gramian.device
@@ -40,7 +53,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
4053
weights = torch.zeros(dimension, device=cpu, dtype=dtype)
4154

4255
for i in range(dimension):
43-
permutation = torch.randperm(dimension)
56+
permutation = torch.randperm(dimension, generator=self.generator)
4457
current_weights = torch.zeros(dimension, device=cpu, dtype=dtype)
4558
current_weights[i] = 1.0
4659

src/torchjd/aggregation/_random.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,43 @@
55
from torchjd._linalg import Matrix
66

77
from ._aggregator_bases import WeightedAggregator
8+
from ._mixins import Stochastic
89
from ._weighting_bases import Weighting
910

1011

11-
class Random(WeightedAggregator):
12+
class Random(WeightedAggregator, Stochastic):
1213
"""
1314
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of
1415
the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of
1516
Random Weighting: A Litmus Test for Multi-Task Learning
1617
<https://arxiv.org/pdf/2111.10603.pdf>`_.
18+
19+
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
20+
the global PyTorch RNG to fork an independent stream.
1721
"""
1822

19-
def __init__(self) -> None:
20-
super().__init__(RandomWeighting())
23+
def __init__(self, seed: int | None = None) -> None:
24+
weighting = RandomWeighting(seed=seed)
25+
WeightedAggregator.__init__(self, weighting)
26+
Stochastic.__init__(self, generator=weighting.generator)
2127

2228

23-
class RandomWeighting(Weighting[Matrix]):
29+
class RandomWeighting(Weighting[Matrix], Stochastic):
2430
"""
2531
:class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights
2632
at each call.
33+
34+
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
35+
the global PyTorch RNG to fork an independent stream.
2736
"""
2837

38+
def __init__(self, seed: int | None = None) -> None:
39+
Weighting.__init__(self)
40+
Stochastic.__init__(self, seed=seed)
41+
2942
def forward(self, matrix: Tensor, /) -> Tensor:
30-
random_vector = torch.randn(matrix.shape[0], device=matrix.device, dtype=matrix.dtype)
43+
random_vector = torch.randn(
44+
matrix.shape[0], device=matrix.device, dtype=matrix.dtype, generator=self.generator
45+
)
3146
weights = F.softmax(random_vector, dim=-1)
3247
return weights

0 commit comments

Comments
 (0)