
Commit 1a2ae67

feat: Add STCH aggregator
Implement the Smooth Tchebycheff (STCH) scalarization algorithm from "Smooth Tchebycheff Scalarization for Multi-Objective Optimization" (https://arxiv.org/abs/2402.19078). The aggregator uses log-sum-exp (smooth maximum) to compute weights that focus more on poorly performing tasks.

Key features:

- mu parameter controls smoothness (smaller = harder max, larger = more uniform)
- Optional warmup_steps for computing a nadir vector from accumulated gradient norms
- reset() method for clearing state between experiments
- Both STCH (Aggregator) and STCHWeighting (Weighting) classes provided
1 parent 2c2a58f commit 1a2ae67

5 files changed

Lines changed: 486 additions & 0 deletions
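
As a quick orientation before the diffs, a minimal usage sketch. This is illustrative only: the jacobian below is made up, and the construction and call convention follow the docstring example in src/torchjd/aggregation/_stch.py further down.

import torch
from torchjd.aggregation import STCH

aggregator = STCH(mu=1.0, warmup_steps=100)  # warmup_steps is optional
jacobian = torch.randn(3, 10)  # hypothetical: one row of task gradients
aggregated = aggregator(jacobian)  # single update direction, shape (10,)
aggregator.reset()  # clear the step counter and nadir between experiments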


docs/source/docs/aggregation/index.rst

Lines changed: 1 addition & 0 deletions
@@ -42,5 +42,6 @@ Abstract base classes
    nash_mtl.rst
    pcgrad.rst
    random.rst
+   stch.rst
    sum.rst
    trimmed_mean.rst
docs/source/docs/aggregation/stch.rst

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
:hide-toc:

STCH
====

.. autoclass:: torchjd.aggregation.STCH
    :members:
    :undoc-members:
    :exclude-members: forward

.. autoclass:: torchjd.aggregation.STCHWeighting
    :members:
    :undoc-members:
    :exclude-members: forward

src/torchjd/aggregation/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -72,6 +72,7 @@
 from ._mgda import MGDA, MGDAWeighting
 from ._pcgrad import PCGrad, PCGradWeighting
 from ._random import Random, RandomWeighting
+from ._stch import STCH, STCHWeighting
 from ._sum import Sum, SumWeighting
 from ._trimmed_mean import TrimmedMean
 from ._upgrad import UPGrad, UPGradWeighting
@@ -104,6 +105,8 @@
     "PCGradWeighting",
     "Random",
     "RandomWeighting",
+    "STCH",
+    "STCHWeighting",
     "Sum",
     "SumWeighting",
     "TrimmedMean",

src/torchjd/aggregation/_stch.py

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
# The code of this file was adapted from
# https://github.com/Xi-L/STCH/blob/main/STCH_MTL/LibMTL/weighting/STCH.py.
# It is therefore also subject to the following license.
#
# MIT License
#
# Copyright (c) 2024 Xi Lin
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._weighting_bases import Weighting


class STCH(GramianWeightedAggregator):
    r"""
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the Smooth Tchebycheff
    scalarization as proposed in `Smooth Tchebycheff Scalarization for Multi-Objective Optimization
    <https://arxiv.org/abs/2402.19078>`_.

    This aggregator uses the log-sum-exp (smooth maximum) function to compute weights that focus
    more on poorly performing tasks (tasks with larger gradient norms). The ``mu`` parameter
    controls the smoothness: as ``mu`` approaches 0, the weights converge to a hard maximum
    (focusing entirely on the worst task); as ``mu`` increases, the weights approach uniform
    averaging.

    :param mu: The smoothness parameter for the log-sum-exp. Smaller values give more weight to
        the worst-performing task. Must be positive.
    :param warmup_steps: Optional number of steps for the warmup phase. During warmup, gradient
        norms are accumulated to compute a nadir vector for normalization. If ``None`` (default),
        no warmup is performed and raw gradient norms are used directly.
    :param eps: A small value to avoid numerical issues in log computations.

    .. warning::
        If ``warmup_steps`` is set, this aggregator becomes stateful. Its output will depend not
        only on the input matrix, but also on its internal state (previously seen matrices). It
        should be reset between experiments using the :meth:`reset` method.

    .. note::
        The original STCH algorithm operates on loss values. This implementation adapts it for
        gradient-based aggregation, using gradient norms (derived from the Gramian diagonal) as
        proxies for task performance.

    Example
    -------

    >>> from torch import tensor
    >>> from torchjd.aggregation import STCH
    >>>
    >>> A = STCH(mu=1.0)
    >>> J = tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
    >>> A(J)
    tensor([1.8188, 1.0000, 1.0000])

    .. note::
        This implementation was adapted from the `official implementation
        <https://github.com/Xi-L/STCH>`_.
    """

    def __init__(
        self,
        mu: float = 1.0,
        warmup_steps: int | None = None,
        eps: float = 1e-20,
    ):
        if mu <= 0.0:
            raise ValueError(f"Parameter `mu` should be a positive float. Found `mu = {mu}`.")

        if warmup_steps is not None and warmup_steps < 1:
            raise ValueError(
                f"Parameter `warmup_steps` should be a positive integer or None. "
                f"Found `warmup_steps = {warmup_steps}`."
            )

        stch_weighting = STCHWeighting(mu=mu, warmup_steps=warmup_steps, eps=eps)
        super().__init__(stch_weighting)

        self._mu = mu
        self._warmup_steps = warmup_steps
        self._eps = eps
        self._stch_weighting = stch_weighting

    def reset(self) -> None:
        """Resets the internal state of the algorithm (step counter and accumulated nadir)."""
        self._stch_weighting.reset()

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(mu={self._mu}, warmup_steps={self._warmup_steps}, "
            f"eps={self._eps})"
        )

    def __str__(self) -> str:
        # ``:g`` drops a trailing ``.0`` without mangling values such as 10
        # (a chained ``rstrip("0").rstrip(".")`` would turn "10" into "1").
        return f"STCH(mu={self._mu:g})"


class STCHWeighting(Weighting[PSDMatrix]):
    r"""
    :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
    :class:`~torchjd.aggregation.STCH`.

    The weights are computed using the Smooth Tchebycheff scalarization formula:

    .. math::

        w_i = \frac{\exp\left(\frac{\log(g_i / z_i) - \max_j \log(g_j / z_j)}{\mu}\right)}
        {\sum_k \exp\left(\frac{\log(g_k / z_k) - \max_j \log(g_j / z_j)}{\mu}\right)}

    where :math:`g_i` is the gradient norm for task :math:`i` (computed as :math:`\sqrt{G_{ii}}`
    from the Gramian), :math:`z_i` is the nadir value for task :math:`i`, and :math:`\mu` is the
    smoothness parameter.

    :param mu: The smoothness parameter for the log-sum-exp. Must be positive.
    :param warmup_steps: Optional number of steps for the warmup phase. During warmup, gradient
        norms are accumulated to compute a nadir vector. If ``None``, no warmup is performed.
    :param eps: A small value to avoid numerical issues in log computations.

    .. warning::
        If ``warmup_steps`` is set, this weighting becomes stateful. During warmup, it returns
        uniform weights while accumulating gradient norms. After warmup, the accumulated average
        is used as the nadir vector for normalization.
    """

    def __init__(
        self,
        mu: float = 1.0,
        warmup_steps: int | None = None,
        eps: float = 1e-20,
    ):
        super().__init__()

        if mu <= 0.0:
            raise ValueError(f"Parameter `mu` should be a positive float. Found `mu = {mu}`.")

        if warmup_steps is not None and warmup_steps < 1:
            raise ValueError(
                f"Parameter `warmup_steps` should be a positive integer or None. "
                f"Found `warmup_steps = {warmup_steps}`."
            )

        self.mu = mu
        self.warmup_steps = warmup_steps
        self.eps = eps

        # Internal state for warmup
        self.step = 0
        self.nadir_accumulator: Tensor | None = None
        self.nadir_vector: Tensor | None = None

    def reset(self) -> None:
        """Resets the internal state of the algorithm."""
        self.step = 0
        self.nadir_accumulator = None
        self.nadir_vector = None

    def forward(self, gramian: PSDMatrix) -> Tensor:
        device = gramian.device
        dtype = gramian.dtype
        m = gramian.shape[0]

        # Compute gradient norms from the Gramian diagonal (sqrt of the diagonal)
        grad_norms = torch.sqrt(torch.diag(gramian).clamp(min=self.eps))

        # Handle the warmup phase if warmup_steps is set
        if self.warmup_steps is not None:
            if self.step < self.warmup_steps:
                # During warmup: accumulate gradient norms and return uniform weights
                if self.nadir_accumulator is None:
                    self.nadir_accumulator = grad_norms.detach().clone()
                else:
                    self.nadir_accumulator = self.nadir_accumulator.to(
                        device=device, dtype=dtype
                    ) + grad_norms.detach()

                self.step += 1

                # Return uniform weights during warmup
                return torch.full(size=[m], fill_value=1.0 / m, device=device, dtype=dtype)

            elif self.nadir_vector is None:
                # First step after warmup: compute the nadir vector from accumulated values
                self.nadir_vector = self.nadir_accumulator / self.warmup_steps  # type: ignore
                self.step += 1
            else:
                self.step += 1

        # Normalize by the nadir vector if available (after warmup)
        if self.nadir_vector is not None:
            nadir = self.nadir_vector.to(device=device, dtype=dtype)
            normalized = grad_norms / nadir.clamp(min=self.eps)
        else:
            normalized = grad_norms

        # Take logs and compute the smooth-max weights, subtracting the max for numerical
        # stability (the softmax is shift-invariant, so this does not change the result).
        log_normalized = torch.log(normalized + self.eps)
        max_log = torch.max(log_normalized)
        reg_log = (log_normalized - max_log) / self.mu

        # Softmax weights
        exp_reg = torch.exp(reg_log)
        weights = exp_reg / exp_reg.sum()

        return weights
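
To make the weighting concrete, the following standalone sketch re-derives the weight formula from the docstring. It is for illustration only: it mirrors forward() above but does not import torchjd, and the gradient norms are made up. It shows how mu interpolates between a hard maximum and uniform averaging.

import torch

def stch_weights(grad_norms: torch.Tensor, mu: float, eps: float = 1e-20) -> torch.Tensor:
    # Softmax of log-normalized gradient norms with temperature mu.
    log_norms = torch.log(grad_norms + eps)
    # Subtracting the max is the standard log-sum-exp stabilization;
    # the softmax is shift-invariant, so the weights are unchanged.
    reg_log = (log_norms - log_norms.max()) / mu
    exp_reg = torch.exp(reg_log)
    return exp_reg / exp_reg.sum()

g = torch.tensor([1.0, 2.0, 4.0])  # hypothetical per-task gradient norms
for mu in (0.1, 1.0, 10.0):
    print(mu, stch_weights(g, mu))
# mu = 0.1 puts nearly all weight on the worst task (hard max);
# mu = 10.0 is close to uniform (roughly 1/3 each).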
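
The warmup logic can likewise be simulated standalone. The sketch below uses hypothetical per-step norms with warmup_steps = 2 and mu = 1: uniform weights are returned while the nadir accumulates, then nadir-normalized weights afterwards.

import torch

# Hypothetical stream of per-step gradient norms for two tasks.
steps = [torch.tensor([2.0, 8.0]), torch.tensor([4.0, 6.0]), torch.tensor([3.0, 9.0])]
warmup_steps = 2
accumulator = torch.zeros(2)
nadir = None

for t, grad_norms in enumerate(steps):
    if t < warmup_steps:
        accumulator += grad_norms  # accumulate norms during warmup
        weights = torch.full((2,), 0.5)  # uniform weights meanwhile
    else:
        if nadir is None:
            nadir = accumulator / warmup_steps  # average warmup norm per task
        normalized = grad_norms / nadir  # compare tasks relative to their usual scale
        log_n = torch.log(normalized)
        weights = torch.softmax(log_n - log_n.max(), dim=0)  # mu = 1
    print(t, weights)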

0 commit comments
