Skip to content

Commit 5a6358f

Browse files
committed
feat(Aggregation): Add ExcessMTLWeighting
1 parent 1aed3c3 commit 5a6358f

7 files changed

Lines changed: 432 additions & 0 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ changelog does not include internal changes that do not affect the user.
1818
Algorithm Based on Decomposition](https://ieeexplore.ieee.org/document/4358754) (IEEE TEVC 2007), a
1919
`Scalarizer` that decomposes the values into a component along a preference direction and a
2020
penalized perpendicular component.
21+
- Added `ExcessMTLWeighting` from [Robust Multi-Task Learning with Excess Risks](https://proceedings.mlr.press/v235/he24n.html) (ICML 2024). It is a stateful `Weighting` that maintains task weights across calls via an exponentiated gradient update driven by per-task excess risk estimates. The excess risk is approximated using an AdaGrad-style diagonal Hessian. An optional `n_warmup_steps` parameter controls how many forward calls collect gradient statistics before weight updates begin.
2122

2223
## [0.15.0] - 2026-06-15
2324

NOTICES

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,34 @@ SOFTWARE.
143143

144144
-------------------------------------------------------------------------------
145145

146+
Project: ExcessMTL
147+
Source: https://github.com/uiuctml/ExcessMTL/blob/main/LibMTL/LibMTL/weighting/ExcessMTL.py
148+
Used in: src/torchjd/aggregation/_excess_mtl.py
149+
150+
MIT License
151+
152+
Copyright (c) 2024 UIUC TML Lab
153+
154+
Permission is hereby granted, free of charge, to any person obtaining a copy
155+
of this software and associated documentation files (the "Software"), to deal
156+
in the Software without restriction, including without limitation the rights
157+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
158+
copies of the Software, and to permit persons to whom the Software is
159+
furnished to do so, subject to the following conditions:
160+
161+
The above copyright notice and this permission notice shall be included in all
162+
copies or substantial portions of the Software.
163+
164+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
165+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
166+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
167+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
168+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
169+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
170+
SOFTWARE.
171+
172+
-------------------------------------------------------------------------------
173+
146174
Project: SDMGrad
147175
Source: https://github.com/OptMN-Lab/SDMGrad/blob/main/methods/weight_methods.py
148176
Used in: src/torchjd/aggregation/_sdmgrad.py
docs/source/docs/aggregation/excess_mtl.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:hide-toc:
2+
3+
ExcessMTL
4+
=========
5+
6+
.. autoclass:: torchjd.aggregation.ExcessMTLWeighting
7+
:members: __call__, reset

docs/source/docs/aggregation/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Abstract base classes
3030
constant.rst
3131
cr_mogm.rst
3232
dualproj.rst
33+
excess_mtl.rst
3334
fairgrad.rst
3435
graddrop.rst
3536
gradvac.rst

src/torchjd/aggregation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from ._constant import Constant, ConstantWeighting
4646
from ._cr_mogm import CRMOGMWeighting
4747
from ._dualproj import DualProj, DualProjWeighting
48+
from ._excess_mtl import ExcessMTLWeighting
4849
from ._fairgrad import FairGrad, FairGradWeighting
4950
from ._graddrop import GradDrop
5051
from ._gradvac import GradVac, GradVacWeighting
@@ -74,6 +75,7 @@
7475
"CRMOGMWeighting",
7576
"DualProj",
7677
"DualProjWeighting",
78+
"ExcessMTLWeighting",
7779
"FairGrad",
7880
"FairGradWeighting",
7981
"GradDrop",
src/torchjd/aggregation/_excess_mtl.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Partly adapted from https://github.com/uiuctml/ExcessMTL — MIT License, Copyright (c) 2024 UIUC TML Lab.
2+
# See NOTICES for the full license text.
3+
from __future__ import annotations
4+
5+
from typing import cast
6+
7+
import torch
8+
from torch import Tensor
9+
10+
from torchjd._mixins import Stateful
11+
from torchjd.aggregation._mixins import _NonDifferentiable
12+
from torchjd.linalg import Matrix
13+
14+
from ._weighting_bases import _MatrixWeighting
15+
16+
17+
class ExcessMTLWeighting(_MatrixWeighting, Stateful, _NonDifferentiable):
    r"""
    :class:`~torchjd.Stateful`
    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Robust
    Multi-Task Learning with Excess Risks
    <https://proceedings.mlr.press/v235/he24n.html>`_ (ICML 2024).

    At each call, task weights are updated via an exponentiated gradient step (Equation 9) driven
    by per-task excess risk estimates. The excess risk for task :math:`i` is approximated via a
    second-order Taylor expansion (Equations 6-7):

    .. math::

        \hat{E}_i \approx \sum_{j=1}^{n} \frac{g_{ij}^2}{\sqrt{S_{ij} + \epsilon}},

    where :math:`g_i` is the :math:`i`-th row of the input matrix, :math:`S_{ij}` is the running
    sum of :math:`g_{ij}^2` over all calls (including the current one) and
    :math:`\epsilon = 10^{-7}`. The weights are then updated as
    :math:`\alpha_i \propto \alpha_i \exp(\eta_\alpha \hat{E}_i / \hat{E}_i^0)`, where
    :math:`\hat{E}_i^0` is the baseline excess risk, and normalized to sum to one.

    :param robust_step_size: Step size :math:`\eta_\alpha` for the exponentiated weight update.
        Must be positive.
    :param n_warmup_steps: Number of forward calls during which weights stay uniform
        (:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. The baseline excess
        risk is set to the average excess risk observed during warmup. When ``0`` (default), the
        first call's excess risk is used as the baseline and weights are updated immediately
        (matching the official implementation).

    .. warning::
        The state tensor :math:`S \in \mathbb{R}^{m \times n}` accumulates squared gradients
        across **all** calls, where :math:`n` is the total number of model parameters. For large
        models this can be a significant memory cost. Call :meth:`reset` between experiments.

    .. note::
        The weight update is adapted from the `official implementation
        <https://github.com/uiuctml/ExcessMTL>`_ and `LibMTL
        <https://github.com/median-research-group/LibMTL/blob/main/LibMTL/weighting/ExcessMTL.py>`_.
        The warmup strategy follows Appendix C.1 of the paper, which recommends collecting
        gradient statistics for several epochs before beginning weight updates; set
        ``n_warmup_steps`` accordingly (e.g. ``3 * len(dataloader)``).

    .. admonition:: Example

        .. testcode::

            import torch
            from torch.nn import Linear, MSELoss, ReLU, Sequential
            from torch.optim import SGD

            from torchjd import autojac
            from torchjd.aggregation import ExcessMTLWeighting, WeightedAggregator
            from torchjd.autojac import jac_to_grad

            inputs = torch.randn(8, 5)
            targets = torch.randn(8, 2)

            model = Sequential(Linear(5, 4), ReLU(), Linear(4, 2))
            optimizer = SGD(model.parameters())
            criterion = MSELoss()
            aggregator = WeightedAggregator(ExcessMTLWeighting())

            outputs = model(inputs)
            losses = [criterion(outputs[:, i], targets[:, i]) for i in range(2)]
            autojac.backward(losses)
            jac_to_grad(model.parameters(), aggregator)
            optimizer.step()
            optimizer.zero_grad()
    """

    def __init__(
        self,
        robust_step_size: float = 1.0,
        n_warmup_steps: int = 0,
    ) -> None:
        super().__init__()
        # Both assignments go through the validating property setters below.
        self.robust_step_size = robust_step_size
        self.n_warmup_steps = n_warmup_steps
        # Lazily-created state; tensors are allocated on the first forward call, once the shape,
        # dtype and device of the input matrix are known (see `_ensure_state`).
        self.register_buffer("_weights", None)  # current task weights, shape [m]
        self.register_buffer("_grad_sum", None)  # running sum of squared gradients, shape [m, n]
        self.register_buffer("_initial_w", None)  # baseline excess risk, shape [m]
        self.register_buffer("_warmup_w_sum", None)  # sum of excess risks seen during warmup
        self.register_buffer("_n_steps", torch.zeros((), dtype=torch.long))
        # (m, n, dtype, device) of the matrix the current state was built for.
        self._state_key: tuple[int, int, torch.dtype, torch.device] | None = None

    @property
    def robust_step_size(self) -> float:
        """Step size of the exponentiated weight update."""
        return self._robust_step_size

    @robust_step_size.setter
    def robust_step_size(self, value: float) -> None:
        if value <= 0.0:
            raise ValueError(
                f"Attribute `robust_step_size` must be positive. Found robust_step_size={value!r}."
            )
        self._robust_step_size = value

    @property
    def n_warmup_steps(self) -> int:
        """Number of forward calls that only collect statistics and return uniform weights."""
        return self._n_warmup_steps

    @n_warmup_steps.setter
    def n_warmup_steps(self, value: int) -> None:
        if value < 0:
            raise ValueError(
                f"Attribute `n_warmup_steps` must be non-negative. Found n_warmup_steps={value!r}."
            )
        self._n_warmup_steps = value

    def reset(self) -> None:
        """Clears all state so the next forward starts from uniform weights and re-enters
        warmup."""

        self._weights = None
        self._grad_sum = None
        self._initial_w = None
        self._warmup_w_sum = None
        self._n_steps.zero_()
        self._state_key = None

    def forward(self, matrix: Matrix, /) -> Tensor:
        """
        Updates the internal statistics with ``matrix`` and returns the current task weights.

        :param matrix: Matrix of shape ``[m, n]`` whose rows are the per-task gradients.
        :return: Weights of shape ``[m]`` that sum to one.
        """
        self._ensure_state(matrix)

        grads = matrix.detach()
        sq_grads = grads**2

        # Accumulate squared gradients for AdaGrad-style diagonal Hessian (Equation 7)
        grad_sum = cast(Tensor, self._grad_sum) + sq_grads
        self._grad_sum = grad_sum

        # Excess risk proxy: Ê_i ≈ g_i^T H_i^{-1} g_i (Equation 6)
        h = torch.sqrt(grad_sum + 1e-7)
        w = (sq_grads / h).sum(dim=1)  # shape [m]

        # Number of forward calls made on the current state *before* this one.
        n_steps = int(self._n_steps.item())
        self._n_steps = self._n_steps + 1

        # Warmup: collect excess risk stats but return uniform weights
        if n_steps < self._n_warmup_steps:
            warmup_w_sum = self._warmup_w_sum
            self._warmup_w_sum = w if warmup_w_sum is None else cast(Tensor, warmup_w_sum) + w
            return cast(Tensor, self._weights)

        # Set baseline on the first non-warmup call
        if self._initial_w is None:
            warmup_w_sum = self._warmup_w_sum
            if warmup_w_sum is not None:
                # Average excess risk observed during warmup (Appendix C.1). Every previous call
                # was a warmup call, so `n_steps` is exactly the number of accumulated samples.
                # Dividing by it (rather than by `n_warmup_steps`) keeps the average correct even
                # if `n_warmup_steps` was modified while warmup was in progress.
                self._initial_w = cast(Tensor, warmup_w_sum) / n_steps
                w = w / (cast(Tensor, self._initial_w) + 1e-7)
            else:
                # Official impl behaviour: first call's excess is the baseline; use w raw
                self._initial_w = w
        else:
            w = w / (cast(Tensor, self._initial_w) + 1e-7)

        # Exponentiated gradient weight update (Equation 9):
        #     weights_i <- weights_i * exp(step * w_i) / sum_j weights_j * exp(step * w_j)
        # It is evaluated in log-space through a softmax, which is mathematically identical but
        # does not overflow when `step * w` is large (e.g. the raw, un-normalized `w` of the very
        # first call). A naive `exp` would produce `inf / inf = nan` weights in that case.
        weights = cast(Tensor, self._weights)
        weights = torch.softmax(torch.log(weights) + w * self._robust_step_size, dim=0)
        self._weights = weights
        return weights

    def _ensure_state(self, matrix: Matrix) -> None:
        """
        (Re)initializes the state if it does not exist or if ``matrix`` differs in shape, dtype or
        device from the matrix the state was built for.
        """
        key = (matrix.shape[0], matrix.shape[1], matrix.dtype, matrix.device)
        if self._state_key == key and self._grad_sum is not None:
            return
        m, n = matrix.shape
        self._grad_sum = matrix.new_zeros(m, n)
        self._weights = matrix.new_full((m,), 1.0 / m)
        self._initial_w = None
        self._warmup_w_sum = None
        self._n_steps.zero_()
        self._state_key = key

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"robust_step_size={self.robust_step_size!r}, "
            f"n_warmup_steps={self.n_warmup_steps!r})"
        )

0 commit comments

Comments
 (0)