Skip to content

Commit 25bd5ac

Browse files
committed
Address Valerian's review comments on ExcessMTL
1 parent ceaaf12 commit 25bd5ac

1 file changed

Lines changed: 31 additions & 27 deletions

File tree

src/torchjd/aggregation/_excess_mtl.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,16 @@ class ExcessMTLWeighting(_MatrixWeighting, Stateful, _NonDifferentiable):
2424
2525
At each call, task weights are updated via an exponentiated gradient step (Equation 9) driven
2626
by per-task excess risk estimates. The excess risk for task :math:`i` is approximated via a
27-
second-order Taylor expansion (Equations 6-7):
27+
second-order Taylor expansion (Equations 6-7).
2828
2929
:param robust_step_size: Step size :math:`\eta_\alpha` for the exponentiated weight update.
3030
Must be positive.
3131
:param n_warmup_steps: Number of forward calls during which weights stay uniform
3232
(:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. The baseline excess
33-
risk is then set to the average excess risk observed during warmup. When ``0``, the first
34-
call's excess risk is used immediately as the baseline. The default ``1`` matches the
35-
behavior of the official implementation and LibMTL. The paper (Appendix C.1) recommends
36-
collecting statistics for 3 full epochs, i.e. ``n_warmup_steps = 3 * len(dataloader)``.
33+
risk is then set to the average excess risk observed during warmup. When ``0`` (default),
34+
the first call's excess risk is used immediately as the baseline, matching the behavior of
35+
the official implementation and LibMTL. The paper (Appendix C.1) recommends collecting
36+
statistics for 3 full epochs, i.e. ``n_warmup_steps = 3 * len(dataloader)``.
3737
3838
.. warning::
3939
The state tensor :math:`S \in \mathbb{R}^{m \times n}` accumulates squared gradients
@@ -51,7 +51,7 @@ class ExcessMTLWeighting(_MatrixWeighting, Stateful, _NonDifferentiable):
5151
def __init__(
5252
self,
5353
robust_step_size: float = 1.0,
54-
n_warmup_steps: int = 1,
54+
n_warmup_steps: int = 0,
5555
) -> None:
5656
super().__init__()
5757
self.robust_step_size = robust_step_size
@@ -87,49 +87,38 @@ def n_warmup_steps(self, value: int) -> None:
8787
)
8888
self._n_warmup_steps = value
8989

90-
def reset(self) -> None:
91-
"""Clears all state so the next forward starts from uniform weights and re-enters
92-
warmup."""
93-
94-
self._weights = None
95-
self._sq_grad_sum = None
96-
self._initial_w = None
97-
self._warmup_w_sum = None
98-
self._n_steps = 0
99-
self._state_key = None
100-
10190
def forward(self, matrix: Matrix, /) -> Tensor:
10291
self._ensure_state(matrix)
10392

10493
sq_matrix = matrix.detach() ** 2
10594

10695
# Accumulate squared gradients for the AdaGrad-style diagonal Hessian estimate (Equation 7)
10796
sq_grad_sum = cast(Tensor, self._sq_grad_sum)
108-
sq_grad_sum += sq_matrix
97+
sq_grad_sum.add_(sq_matrix)
10998

11099
# Excess risk proxy: Ê_i ≈ g_i^T H_i^{-1} g_i (Equation 6)
111100
h = torch.sqrt(sq_grad_sum + 1e-7)
112101
w = (sq_matrix / h).sum(dim=1) # shape [m]
113102

114-
n_steps = self._n_steps
115-
self._n_steps += 1
116-
117103
# Warmup: collect excess risk stats but return uniform weights
118-
if n_steps < self._n_warmup_steps:
104+
if self._n_steps < self._n_warmup_steps:
119105
cast(Tensor, self._warmup_w_sum).add_(w)
106+
self._n_steps += 1
120107
return cast(Tensor, self._weights)
121108

109+
self._n_steps += 1
110+
122111
# Set baseline on the first non-warmup call
123112
if self._initial_w is None:
124113
if self._n_warmup_steps > 0:
125114
# Average excess risk observed during warmup (Appendix C.1)
126115
self._initial_w = cast(Tensor, self._warmup_w_sum) / self._n_warmup_steps
127-
w = w / (self._initial_w + 1e-7)
116+
w = w / (self._initial_w + 1e-7) # Scale processing (Section 3.2)
128117
else:
129118
# Official implementation: first call's excess risk is the baseline; use w unscaled
130119
self._initial_w = w
131120
else:
132-
w = w / (self._initial_w + 1e-7)
121+
w = w / (self._initial_w + 1e-7) # Scale processing (Section 3.2)
133122

134123
# Exponentiated gradient weight update (Equation 9)
135124
weights = cast(Tensor, self._weights)
@@ -138,6 +127,17 @@ def forward(self, matrix: Matrix, /) -> Tensor:
138127
self._weights = weights
139128
return weights
140129

130+
def reset(self) -> None:
131+
"""Clears all state so the next forward call starts from uniform weights and re-enters
132+
warmup."""
133+
134+
self._weights = None
135+
self._sq_grad_sum = None
136+
self._initial_w = None
137+
self._warmup_w_sum = None
138+
self._n_steps = 0
139+
self._state_key = None
140+
141141
def _ensure_state(self, matrix: Matrix) -> None:
142142
key = (matrix.shape[0], matrix.shape[1], matrix.dtype, matrix.device)
143143
if self._state_key == key and self._sq_grad_sum is not None:
@@ -160,6 +160,7 @@ def __repr__(self) -> str:
160160

161161
class ExcessMTL(WeightedAggregator, Stateful, _NonDifferentiable):
162162
r"""
163+
:class:`~torchjd.Stateful`
163164
:class:`~torchjd.aggregation.WeightedAggregator` from `Robust Multi-Task Learning with Excess
164165
Risks <https://proceedings.mlr.press/v235/he24n.html>`_ (ICML 2024).
165166
@@ -170,16 +171,19 @@ class ExcessMTL(WeightedAggregator, Stateful, _NonDifferentiable):
170171
:param robust_step_size: Step size :math:`\eta_\alpha` for the exponentiated weight update.
171172
Must be positive.
172173
:param n_warmup_steps: Number of forward calls during which weights stay uniform
173-
(:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. When ``0``, the first
174-
call's excess risk is used as the baseline immediately. Defaults to ``1``.
174+
(:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. The baseline excess
175+
risk is then set to the average excess risk observed during warmup. When ``0`` (default),
176+
the first call's excess risk is used immediately as the baseline, matching the behavior of
177+
the official implementation and LibMTL. The paper (Appendix C.1) recommends collecting
178+
statistics for 3 full epochs, i.e. ``n_warmup_steps = 3 * len(dataloader)``.
175179
"""
176180

177181
weighting: ExcessMTLWeighting
178182

179183
def __init__(
180184
self,
181185
robust_step_size: float = 1.0,
182-
n_warmup_steps: int = 1,
186+
n_warmup_steps: int = 0,
183187
) -> None:
184188
super().__init__(ExcessMTLWeighting(robust_step_size, n_warmup_steps))
185189

0 commit comments

Comments
 (0)