Commit 4fb16da

Add CAGrad setters
1 parent 481bf67 commit 4fb16da

1 file changed: 59 additions & 36 deletions

src/torchjd/aggregation/_cagrad.py
@@ -18,38 +18,6 @@
 from ._utils.non_differentiable import raise_non_differentiable_error
 
 
-class CAGrad(GramianWeightedAggregator):
-    """
-    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
-    `Conflict-Averse Gradient Descent for Multi-task Learning
-    <https://arxiv.org/pdf/2110.14048.pdf>`_.
-
-    :param c: The scale of the radius of the ball constraint.
-    :param norm_eps: A small value to avoid division by zero when normalizing.
-
-    .. note::
-        This aggregator is not installed by default. When not installed, trying to import it should
-        result in the following error:
-        ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
-        To install it, use ``pip install "torchjd[cagrad]"``.
-    """
-
-    def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
-        super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
-        self._c = c
-        self._norm_eps = norm_eps
-
-        # This prevents considering the computed weights as constant w.r.t. the matrix.
-        self.register_full_backward_pre_hook(raise_non_differentiable_error)
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"
-
-    def __str__(self) -> str:
-        c_str = str(self._c).rstrip("0")
-        return f"CAGrad{c_str}"
-
-
 class CAGradWeighting(Weighting[PSDMatrix]):
     """
     :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
@@ -69,13 +37,22 @@ class CAGradWeighting(Weighting[PSDMatrix]):
 
     def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
         super().__init__()
-
-        if c < 0.0:
-            raise ValueError(f"Parameter `c` should be a non-negative float. Found `c = {c}`.")
-
         self.c = c
         self.norm_eps = norm_eps
 
+    @property
+    def c(self) -> float:
+        return self._c
+
+    @c.setter
+    def c(self, value: float) -> None:
+        if value < 0.0:
+            raise ValueError(
+                f"Parameter `value` should be a non-negative float. Found `value = {value}`."
+            )
+
+        self._c = value
+
     def forward(self, gramian: PSDMatrix, /) -> Tensor:
         U, S, _ = torch.svd(normalize(gramian, self.norm_eps))
 

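Aside: a minimal sketch of what this hunk buys (not part of the commit; the import from the private module `torchjd.aggregation._cagrad` is an assumption for illustration). Because `__init__` now assigns through the property, the non-negativity check guards every later assignment of `c`, not only the constructor argument:

from torchjd.aggregation._cagrad import CAGradWeighting  # private-module import: assumed for illustration

weighting = CAGradWeighting(c=0.5)  # validated: __init__ assigns via the property setter
weighting.c = 1.0                   # accepted: non-negative values pass the setter's check

try:
    weighting.c = -0.1              # rejected: the setter raises for negative values
except ValueError as error:
    print(error)  # Parameter `value` should be a non-negative float. Found `value = -0.1`.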
@@ -104,3 +81,49 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
         weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)
 
         return weights
+
+
+class CAGrad(GramianWeightedAggregator[CAGradWeighting]):
+    """
+    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
+    `Conflict-Averse Gradient Descent for Multi-task Learning
+    <https://arxiv.org/pdf/2110.14048.pdf>`_.
+
+    :param c: The scale of the radius of the ball constraint.
+    :param norm_eps: A small value to avoid division by zero when normalizing.
+
+    .. note::
+        This aggregator is not installed by default. When not installed, trying to import it should
+        result in the following error:
+        ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
+        To install it, use ``pip install "torchjd[cagrad]"``.
+    """
+
+    def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
+        super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
+        self._c = c
+        self._norm_eps = norm_eps
+
+        # This prevents considering the computed weights as constant w.r.t. the matrix.
+        self.register_full_backward_pre_hook(raise_non_differentiable_error)
+
+    @property
+    def c(self) -> float:
+        return self._c
+
+    @c.setter
+    def c(self, value: float) -> None:
+        if value < 0.0:
+            raise ValueError(
+                f"Parameter `value` should be a non-negative float. Found `value = {value}`."
+            )
+
+        self._c = value
+        self.gramian_weighting.c = value
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"
+
+    def __str__(self) -> str:
+        c_str = str(self._c).rstrip("0")
+        return f"CAGrad{c_str}"
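Similarly, a short sketch of the new public setter on `CAGrad` itself (not part of the commit; it assumes the `torchjd[cagrad]` extra is installed, as the docstring notes). Assigning to `c` updates both the `_c` cached for `__repr__`/`__str__` and the `c` of the wrapped weighting, via the `gramian_weighting` attribute shown in the diff, so the two can no longer drift apart:

from torchjd.aggregation import CAGrad  # requires: pip install "torchjd[cagrad]"

aggregator = CAGrad(c=0.5)
print(repr(aggregator))                # CAGrad(c=0.5, norm_eps=0.0001)
print(aggregator)                      # CAGrad0.5

aggregator.c = 2.5                     # also runs `self.gramian_weighting.c = value`
print(aggregator.gramian_weighting.c)  # 2.5
print(repr(aggregator))                # CAGrad(c=2.5, norm_eps=0.0001)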
