Skip to content

Commit 5a373b4

Browse files
feat(aggregation): Add ExcessMTL (#747)
1 parent 47018db commit 5a373b4

7 files changed

Lines changed: 503 additions & 0 deletions

File tree

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@ changelog does not include internal changes that do not affect the user.
88

99
## [Unreleased]
1010

11+
### Added
12+
13+
- Added `ExcessMTL` and `ExcessMTLWeighting` from [Robust Multi-Task Learning with Excess
14+
Risks](https://proceedings.mlr.press/v235/he24n.html) (ICML 2024). `ExcessMTLWeighting` is a
15+
stateful `Weighting` that maintains task weights across calls via an exponentiated gradient update
16+
driven by per-task excess risk estimates. The excess risk is approximated using an AdaGrad-style
17+
diagonal Hessian. An optional `n_warmup_steps` parameter controls how many forward calls collect
18+
gradient statistics before weight updates begin.
19+
1120
## [0.16.0] - 2026-06-22
1221

1322
### Added

NOTICES

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,34 @@ SOFTWARE.
143143

144144
-------------------------------------------------------------------------------
145145

146+
Project: ExcessMTL
147+
Source: https://github.com/uiuctml/ExcessMTL/blob/main/LibMTL/LibMTL/weighting/ExcessMTL.py
148+
Used in: src/torchjd/aggregation/_excess_mtl.py
149+
150+
MIT License
151+
152+
Copyright (c) 2024 UIUC TML Lab
153+
154+
Permission is hereby granted, free of charge, to any person obtaining a copy
155+
of this software and associated documentation files (the "Software"), to deal
156+
in the Software without restriction, including without limitation the rights
157+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
158+
copies of the Software, and to permit persons to whom the Software is
159+
furnished to do so, subject to the following conditions:
160+
161+
The above copyright notice and this permission notice shall be included in all
162+
copies or substantial portions of the Software.
163+
164+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
165+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
166+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
167+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
168+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
169+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
170+
SOFTWARE.
171+
172+
-------------------------------------------------------------------------------
173+
146174
Project: SDMGrad
147175
Source: https://github.com/OptMN-Lab/SDMGrad/blob/main/methods/weight_methods.py
148176
Used in: src/torchjd/aggregation/_sdmgrad.py
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
:hide-toc:
2+
3+
ExcessMTL
4+
=========
5+
6+
.. autoclass:: torchjd.aggregation.ExcessMTL
7+
:members: __call__, reset
8+
9+
.. autoclass:: torchjd.aggregation.ExcessMTLWeighting
10+
:members: __call__, reset

docs/source/docs/aggregation/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Abstract base classes
3030
constant.rst
3131
cr_mogm.rst
3232
dualproj.rst
33+
excess_mtl.rst
3334
fairgrad.rst
3435
graddrop.rst
3536
gradvac.rst

src/torchjd/aggregation/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from ._constant import Constant, ConstantWeighting
4646
from ._cr_mogm import CRMOGMWeighting
4747
from ._dualproj import DualProj, DualProjWeighting
48+
from ._excess_mtl import ExcessMTL, ExcessMTLWeighting
4849
from ._fairgrad import FairGrad, FairGradWeighting
4950
from ._graddrop import GradDrop
5051
from ._gradvac import GradVac, GradVacWeighting
@@ -74,6 +75,8 @@
7475
"CRMOGMWeighting",
7576
"DualProj",
7677
"DualProjWeighting",
78+
"ExcessMTL",
79+
"ExcessMTLWeighting",
7780
"FairGrad",
7881
"FairGradWeighting",
7982
"GradDrop",
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# Partly adapted from https://github.com/uiuctml/ExcessMTL — MIT License, Copyright (c) 2024 UIUC TML Lab.
2+
# See NOTICES for the full license text.
3+
from __future__ import annotations
4+
5+
from typing import cast
6+
7+
import torch
8+
from torch import Tensor
9+
10+
from torchjd._mixins import Stateful
11+
from torchjd.aggregation._mixins import _NonDifferentiable
12+
from torchjd.linalg import Matrix
13+
14+
from ._aggregator_bases import WeightedAggregator
15+
from ._weighting_bases import _MatrixWeighting
16+
17+
18+
class ExcessMTLWeighting(_MatrixWeighting, Stateful, _NonDifferentiable):
    r"""
    :class:`~torchjd.Stateful`
    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Robust
    Multi-Task Learning with Excess Risks
    <https://proceedings.mlr.press/v235/he24n.html>`_ (ICML 2024).

    At each call, task weights are updated via an exponentiated gradient step (Equation 9) driven
    by per-task excess risk estimates. The excess risk for task :math:`i` is approximated via a
    second-order Taylor expansion (Equations 6-7).

    :param robust_step_size: Step size :math:`\eta_\alpha` for the exponentiated weight update.
        Must be positive.
    :param n_warmup_steps: Number of forward calls during which weights stay uniform
        (:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. The baseline excess
        risk is then set to the average excess risk observed during warmup. When ``0`` (default),
        the first call's excess risk is used immediately as the baseline, matching the behavior of
        the official implementation and LibMTL. The paper (Appendix C.1) recommends collecting
        statistics for 3 full epochs, i.e. ``n_warmup_steps = 3 * len(dataloader)``.

    .. warning::
        The state tensor :math:`S \in \mathbb{R}^{m \times n}` accumulates squared gradients
        across calls, where :math:`n` is the total number of model parameters. For large
        models this can be a significant memory cost. Call :meth:`reset` between experiments.

    .. note::
        The weight update is adapted from the `official implementation
        <https://github.com/uiuctml/ExcessMTL>`_ and `LibMTL
        <https://github.com/median-research-group/LibMTL/blob/main/LibMTL/weighting/ExcessMTL.py>`_.
        Unlike those implementations, which initialize task weights to ``1``, we follow the paper
        and initialize them to ``1/m`` so that they always lie on the probability simplex.

    .. note::
        The exponentiated update is evaluated after subtracting the largest exponent. Because the
        weights are normalized right after, this leaves them mathematically unchanged, but it
        prevents ``exp`` from overflowing (which would turn every weight into ``nan``) when an
        excess risk estimate is large.
    """

    def __init__(
        self,
        robust_step_size: float = 1.0,
        n_warmup_steps: int = 0,
    ) -> None:
        super().__init__()
        # Both assignments go through the validating property setters below.
        self.robust_step_size = robust_step_size
        self.n_warmup_steps = n_warmup_steps
        # Tensor state is allocated lazily by `_ensure_state` on the first forward call, once the
        # shape, dtype and device of the input matrix are known.
        self.register_buffer("_weights", None)  # current task weights, shape [m]
        self.register_buffer("_sq_grad_sum", None)  # running sum of squared gradients, [m, n]
        self.register_buffer("_initial_w", None)  # baseline excess risk per task, shape [m]
        self.register_buffer("_warmup_w_sum", None)  # excess risk summed over warmup, shape [m]
        # Number of forward calls made since the state was last (re)initialized.
        self._n_steps: int = 0
        # (m, n, dtype, device) of the matrix the current state was built for.
        self._state_key: tuple[int, int, torch.dtype, torch.device] | None = None

    @property
    def robust_step_size(self) -> float:
        """Step size of the exponentiated weight update."""

        return self._robust_step_size

    @robust_step_size.setter
    def robust_step_size(self, value: float) -> None:
        if value <= 0.0:
            raise ValueError(
                f"Attribute `robust_step_size` must be positive. Found robust_step_size={value!r}."
            )
        self._robust_step_size = value

    @property
    def n_warmup_steps(self) -> int:
        """Number of initial forward calls that only collect statistics."""

        return self._n_warmup_steps

    @n_warmup_steps.setter
    def n_warmup_steps(self, value: int) -> None:
        if value < 0:
            raise ValueError(
                f"Attribute `n_warmup_steps` must be non-negative. Found n_warmup_steps={value!r}."
            )
        self._n_warmup_steps = value

    def forward(self, matrix: Matrix, /) -> Tensor:
        """
        Updates the internal statistics with ``matrix`` and returns the task weights.

        :param matrix: Matrix with one row per task (shape ``[m, n]``).
        :returns: Weights of shape ``[m]``. They are uniform during warmup, and the result of the
            normalized exponentiated update afterwards.
        """

        self._ensure_state(matrix)

        sq_matrix = matrix.detach() ** 2

        # Accumulate squared gradients for AdaGrad-style diagonal Hessian (Equation 7)
        sq_grad_sum = cast(Tensor, self._sq_grad_sum)
        sq_grad_sum.add_(sq_matrix)

        # Excess risk proxy: Ê_i ≈ g_i^T H_i^{-1} g_i (Equation 6)
        # The 1e-7 term keeps the division well-defined where the accumulated sum is zero.
        h = torch.sqrt(sq_grad_sum + 1e-7)
        w = (sq_matrix / h).sum(dim=1)  # shape [m]

        # Warmup: collect excess risk stats but return uniform weights
        if self._n_steps < self._n_warmup_steps:
            cast(Tensor, self._warmup_w_sum).add_(w)
            self._n_steps += 1
            return cast(Tensor, self._weights)

        self._n_steps += 1

        # Set baseline on the first non-warmup call
        if self._initial_w is None:
            if self._n_warmup_steps > 0:
                # Average excess risk observed during warmup (Appendix C.1)
                self._initial_w = cast(Tensor, self._warmup_w_sum) / self._n_warmup_steps
                w = w / (self._initial_w + 1e-7)  # Scale processing (Section 3.2)
            else:
                # Official impl behavior: first call's excess is the baseline; use w raw
                self._initial_w = w
        else:
            w = w / (self._initial_w + 1e-7)  # Scale processing (Section 3.2)

        # Exponentiated gradient weight update (Equation 9).
        # Shifting all exponents by their maximum multiplies every weight by the same positive
        # factor, which cancels in the normalization below. It guarantees exp() never overflows to
        # inf (inf / inf would make all weights nan), e.g. when the raw first-call excess risk is
        # large or when a baseline entry is close to zero.
        exponent = w * self._robust_step_size
        exponent = exponent - exponent.max()
        weights = cast(Tensor, self._weights) * torch.exp(exponent)
        weights = weights / weights.sum()
        self._weights = weights
        return weights

    def reset(self) -> None:
        """Clears all state so the next forward starts from uniform weights and re-enters
        warmup."""

        self._weights = None
        self._sq_grad_sum = None
        self._initial_w = None
        self._warmup_w_sum = None
        self._n_steps = 0
        self._state_key = None

    def _ensure_state(self, matrix: Matrix) -> None:
        """
        Allocates fresh state if none exists or if ``matrix`` does not match the current state.

        The state is considered matching when the number of rows, number of columns, dtype and
        device of ``matrix`` are the same as for the matrix the state was created from. Otherwise,
        everything restarts: zeroed statistics, uniform weights, no baseline and a step count of 0.
        """

        key = (matrix.shape[0], matrix.shape[1], matrix.dtype, matrix.device)
        if self._state_key == key and self._sq_grad_sum is not None:
            return
        m, n = matrix.shape
        self._sq_grad_sum = matrix.new_zeros(m, n)
        self._warmup_w_sum = matrix.new_zeros(m)
        self._weights = matrix.new_full((m,), 1.0 / m)
        self._initial_w = None
        self._n_steps = 0
        self._state_key = key

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"robust_step_size={self.robust_step_size!r}, "
            f"n_warmup_steps={self.n_warmup_steps!r})"
        )
159+
160+
161+
class ExcessMTL(WeightedAggregator, Stateful, _NonDifferentiable):
    r"""
    :class:`~torchjd.Stateful`
    :class:`~torchjd.aggregation.WeightedAggregator` from `Robust Multi-Task Learning with Excess
    Risks <https://proceedings.mlr.press/v235/he24n.html>`_ (ICML 2024).

    The task weights come from an exponentiated gradient step (Equation 9) that is applied at
    every call and is driven by per-task excess risk estimates. The algorithm and its state are
    implemented by :class:`~torchjd.aggregation.ExcessMTLWeighting`; this class only wraps it.

    :param robust_step_size: Step size :math:`\eta_\alpha` for the exponentiated weight update.
        Must be positive.
    :param n_warmup_steps: Number of forward calls during which weights stay uniform
        (:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. The baseline excess
        risk is then set to the average excess risk observed during warmup. When ``0`` (default),
        the first call's excess risk is used immediately as the baseline, matching the behavior of
        the official implementation and LibMTL. The paper (Appendix C.1) recommends collecting
        statistics for 3 full epochs, i.e. ``n_warmup_steps = 3 * len(dataloader)``.
    """

    weighting: ExcessMTLWeighting

    def __init__(
        self,
        robust_step_size: float = 1.0,
        n_warmup_steps: int = 0,
    ) -> None:
        # The weighting validates both hyper-parameters and owns all of the state.
        excess_weighting = ExcessMTLWeighting(
            robust_step_size=robust_step_size,
            n_warmup_steps=n_warmup_steps,
        )
        super().__init__(excess_weighting)

    # The two hyper-parameters are not stored here: reads and writes are forwarded to the
    # underlying weighting.

    @property
    def robust_step_size(self) -> float:
        return self.weighting.robust_step_size

    @robust_step_size.setter
    def robust_step_size(self, value: float) -> None:
        self.weighting.robust_step_size = value

    @property
    def n_warmup_steps(self) -> int:
        return self.weighting.n_warmup_steps

    @n_warmup_steps.setter
    def n_warmup_steps(self, value: int) -> None:
        self.weighting.n_warmup_steps = value

    def reset(self) -> None:
        """Clears all state so the next forward starts from uniform weights and re-enters
        warmup."""

        self.weighting.reset()

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        step_size = self.robust_step_size
        warmup = self.n_warmup_steps
        return f"{cls_name}(robust_step_size={step_size!r}, n_warmup_steps={warmup!r})"

0 commit comments

Comments
 (0)