1818from ._utils .non_differentiable import raise_non_differentiable_error
1919
2020
21- class CAGrad (GramianWeightedAggregator ):
22- """
23- :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
24- `Conflict-Averse Gradient Descent for Multi-task Learning
25- <https://arxiv.org/pdf/2110.14048.pdf>`_.
26-
27- :param c: The scale of the radius of the ball constraint.
28- :param norm_eps: A small value to avoid division by zero when normalizing.
29-
30- .. note::
31- This aggregator is not installed by default. When not installed, trying to import it should
32- result in the following error:
33- ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
34- To install it, use ``pip install "torchjd[cagrad]"``.
35- """
36-
37- def __init__ (self , c : float , norm_eps : float = 0.0001 ) -> None :
38- super ().__init__ (CAGradWeighting (c = c , norm_eps = norm_eps ))
39- self ._c = c
40- self ._norm_eps = norm_eps
41-
42- # This prevents considering the computed weights as constant w.r.t. the matrix.
43- self .register_full_backward_pre_hook (raise_non_differentiable_error )
44-
45- def __repr__ (self ) -> str :
46- return f"{ self .__class__ .__name__ } (c={ self ._c } , norm_eps={ self ._norm_eps } )"
47-
48- def __str__ (self ) -> str :
49- c_str = str (self ._c ).rstrip ("0" )
50- return f"CAGrad{ c_str } "
51-
52-
5321class CAGradWeighting (Weighting [PSDMatrix ]):
5422 """
5523 :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
@@ -69,13 +37,22 @@ class CAGradWeighting(Weighting[PSDMatrix]):
6937
7038 def __init__ (self , c : float , norm_eps : float = 0.0001 ) -> None :
7139 super ().__init__ ()
72-
73- if c < 0.0 :
74- raise ValueError (f"Parameter `c` should be a non-negative float. Found `c = { c } `." )
75-
7640 self .c = c
7741 self .norm_eps = norm_eps
7842
43+ @property
44+ def c (self ) -> float :
45+ return self ._c
46+
47+ @c .setter
48+ def c (self , value : float ) -> None :
49+ if value < 0.0 :
50+ raise ValueError (
51+ f"Parameter `value` should be a non-negative float. Found `value = { value } `."
52+ )
53+
54+ self ._c = value
55+
7956 def forward (self , gramian : PSDMatrix , / ) -> Tensor :
8057 U , S , _ = torch .svd (normalize (gramian , self .norm_eps ))
8158
@@ -104,3 +81,49 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
10481 weights = torch .from_numpy (weight_array ).to (device = gramian .device , dtype = gramian .dtype )
10582
10683 return weights
84+
85+
86+ class CAGrad (GramianWeightedAggregator [CAGradWeighting ]):
87+ """
88+ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
89+ `Conflict-Averse Gradient Descent for Multi-task Learning
90+ <https://arxiv.org/pdf/2110.14048.pdf>`_.
91+
92+ :param c: The scale of the radius of the ball constraint.
93+ :param norm_eps: A small value to avoid division by zero when normalizing.
94+
95+ .. note::
96+ This aggregator is not installed by default. When not installed, trying to import it should
97+ result in the following error:
98+ ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
99+ To install it, use ``pip install "torchjd[cagrad]"``.
100+ """
101+
102+ def __init__ (self , c : float , norm_eps : float = 0.0001 ) -> None :
103+ super ().__init__ (CAGradWeighting (c = c , norm_eps = norm_eps ))
104+ self ._c = c
105+ self ._norm_eps = norm_eps
106+
107+ # This prevents considering the computed weights as constant w.r.t. the matrix.
108+ self .register_full_backward_pre_hook (raise_non_differentiable_error )
109+
110+ @property
111+ def c (self ) -> float :
112+ return self ._c
113+
114+ @c .setter
115+ def c (self , value : float ) -> None :
116+ if value < 0.0 :
117+ raise ValueError (
118+ f"Parameter `value` should be a non-negative float. Found `value = { value } `."
119+ )
120+
121+ self ._c = value
122+ self .gramian_weighting .c = value
123+
124+ def __repr__ (self ) -> str :
125+ return f"{ self .__class__ .__name__ } (c={ self ._c } , norm_eps={ self ._norm_eps } )"
126+
127+ def __str__ (self ) -> str :
128+ c_str = str (self ._c ).rstrip ("0" )
129+ return f"CAGrad{ c_str } "
0 commit comments