44from torch import Tensor
55
66from torchjd ._linalg import PSDMatrix
7+ from torchjd .aggregation import Stateful
8+ from torchjd .aggregation ._mixins import StochasticState
79
810from ._aggregator_bases import GramianWeightedAggregator
9- from ._mixins import Stochastic
1011from ._utils .non_differentiable import raise_non_differentiable_error
1112from ._weighting_bases import Weighting
1213
1314
14- class PCGradWeighting (Weighting [PSDMatrix ], Stochastic ):
15+ class PCGradWeighting (Weighting [PSDMatrix ], Stateful ):
1516 """
1617 :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
1718 :class:`~torchjd.aggregation.PCGrad`.
@@ -21,8 +22,8 @@ class PCGradWeighting(Weighting[PSDMatrix], Stochastic):
2122 """
2223
2324 def __init__ (self , seed : int | None = None ) -> None :
24- Weighting .__init__ (self )
25- Stochastic . __init__ ( self , seed = seed )
25+ super () .__init__ ()
26+ self . state = StochasticState ( seed = seed )
2627
2728 def forward (self , gramian : PSDMatrix , / ) -> Tensor :
2829 # Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
@@ -35,7 +36,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
3536 weights = torch .zeros (dimension , device = cpu , dtype = dtype )
3637
3738 for i in range (dimension ):
38- permutation = torch .randperm (dimension , generator = self .generator )
39+ permutation = torch .randperm (dimension , generator = self .state . generator )
3940 current_weights = torch .zeros (dimension , device = cpu , dtype = dtype )
4041 current_weights [i ] = 1.0
4142
@@ -54,7 +55,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
5455 return weights .to (device )
5556
5657
57- class PCGrad (GramianWeightedAggregator , Stochastic ):
58+ class PCGrad (GramianWeightedAggregator , Stateful ):
5859 """
5960 :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of
6061 `Gradient Surgery for Multi-Task Learning <https://arxiv.org/pdf/2001.06782.pdf>`_.
@@ -64,9 +65,7 @@ class PCGrad(GramianWeightedAggregator, Stochastic):
6465 """
6566
6667 def __init__ (self , seed : int | None = None ) -> None :
67- weighting = PCGradWeighting (seed = seed )
68- GramianWeightedAggregator .__init__ (self , weighting )
69- Stochastic .__init__ (self , generator = weighting .generator )
68+ super ().__init__ (PCGradWeighting (seed = seed ))
7069
7170 # This prevents running into a RuntimeError due to modifying stored tensors in place.
7271 self .register_full_backward_pre_hook (raise_non_differentiable_error )
0 commit comments