Skip to content

Commit 1d85ba4

Browse files
committed
[WIP] transform Stochastic into StochasticState
1 parent 477f19f commit 1d85ba4

2 files changed

Lines changed: 13 additions & 22 deletions

File tree

src/torchjd/aggregation/_mixins.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,17 @@ def reset(self) -> None:
2222
"""Resets the internal state :math:`s_0`."""
2323

2424

25-
class Stochastic(Stateful, ABC):
25+
class StochasticState(Stateful):
2626
r"""
27-
Stateful mixin that represents mappings that have inherent randomness.
27+
State representing stochasticity.
2828
29-
Internally, a ``Stochastic`` mapping holds a :class:`torch.Generator` that serves as an
30-
independent random number stream. Implementing classes must pass this generator to all torch
31-
random functions via their ``generator`` argument, e.g.:
32-
33-
.. code-block:: python
34-
35-
torch.rand(n, generator=self.generator)
36-
torch.randn(n, generator=self.generator)
37-
torch.randperm(n, generator=self.generator)
29+
Internally, a ``StochasticState`` mapping holds a :class:`torch.Generator` that serves as an
30+
independent random number stream.
3831
3932
:param seed: Seed for the internal :class:`torch.Generator`. If ``None``, a seed is drawn
4033
from the global PyTorch RNG to fork an independent stream.
4134
:param generator: An existing :class:`torch.Generator` to share, typically from a companion
42-
:class:`Stochastic` instance (e.g. a :class:`Weighting` sharing the generator of its
43-
:class:`Aggregator`). Mutually exclusive with ``seed``.
35+
:class:`StochasticState` instance. Mutually exclusive with ``seed``.
4436
"""
4537

4638
def __init__(self, seed: int | None = None, generator: torch.Generator | None = None) -> None:

src/torchjd/aggregation/_pcgrad.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
from torch import Tensor
55

66
from torchjd._linalg import PSDMatrix
7+
from torchjd.aggregation import Stateful
8+
from torchjd.aggregation._mixins import StochasticState
79

810
from ._aggregator_bases import GramianWeightedAggregator
9-
from ._mixins import Stochastic
1011
from ._utils.non_differentiable import raise_non_differentiable_error
1112
from ._weighting_bases import Weighting
1213

1314

14-
class PCGradWeighting(Weighting[PSDMatrix], Stochastic):
15+
class PCGradWeighting(Weighting[PSDMatrix], Stateful):
1516
"""
1617
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
1718
:class:`~torchjd.aggregation.PCGrad`.
@@ -21,8 +22,8 @@ class PCGradWeighting(Weighting[PSDMatrix], Stochastic):
2122
"""
2223

2324
def __init__(self, seed: int | None = None) -> None:
24-
Weighting.__init__(self)
25-
Stochastic.__init__(self, seed=seed)
25+
super().__init__()
26+
self.state = StochasticState(seed=seed)
2627

2728
def forward(self, gramian: PSDMatrix, /) -> Tensor:
2829
# Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
@@ -35,7 +36,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
3536
weights = torch.zeros(dimension, device=cpu, dtype=dtype)
3637

3738
for i in range(dimension):
38-
permutation = torch.randperm(dimension, generator=self.generator)
39+
permutation = torch.randperm(dimension, generator=self.state.generator)
3940
current_weights = torch.zeros(dimension, device=cpu, dtype=dtype)
4041
current_weights[i] = 1.0
4142

@@ -54,7 +55,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
5455
return weights.to(device)
5556

5657

57-
class PCGrad(GramianWeightedAggregator, Stochastic):
58+
class PCGrad(GramianWeightedAggregator, Stateful):
5859
"""
5960
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of
6061
`Gradient Surgery for Multi-Task Learning <https://arxiv.org/pdf/2001.06782.pdf>`_.
@@ -64,9 +65,7 @@ class PCGrad(GramianWeightedAggregator, Stochastic):
6465
"""
6566

6667
def __init__(self, seed: int | None = None) -> None:
67-
weighting = PCGradWeighting(seed=seed)
68-
GramianWeightedAggregator.__init__(self, weighting)
69-
Stochastic.__init__(self, generator=weighting.generator)
68+
super().__init__(PCGradWeighting(seed=seed))
7069

7170
# This prevents running into a RuntimeError due to modifying stored tensors in place.
7271
self.register_full_backward_pre_hook(raise_non_differentiable_error)

0 commit comments

Comments
 (0)