@@ -24,16 +24,16 @@ class ExcessMTLWeighting(_MatrixWeighting, Stateful, _NonDifferentiable):
2424
2525 At each call, task weights are updated via an exponentiated gradient step (Equation 9) driven
2626 by per-task excess risk estimates. The excess risk for task :math:`i` is approximated via a
27- second-order Taylor expansion (Equations 6-7):
27+ second-order Taylor expansion (Equations 6-7).
2828
2929 :param robust_step_size: Step size :math:`\eta_\alpha` for the exponentiated weight update.
3030 Must be positive.
3131 :param n_warmup_steps: Number of forward calls during which weights stay uniform
3232 (:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. The baseline excess
33- risk is then set to the average excess risk observed during warmup. When ``0``, the first
34- call's excess risk is used immediately as the baseline. The default ``1`` matches the
35- behavior of the official implementation and LibMTL. The paper (Appendix C.1) recommends
36- collecting statistics for 3 full epochs, i.e. ``n_warmup_steps = 3 * len(dataloader)``.
33+ risk is then set to the average excess risk observed during warmup. When ``0`` (default),
34+ the first call's excess risk is used immediately as the baseline, matching the behavior of
35+ the official implementation and LibMTL. The paper (Appendix C.1) recommends collecting
36+ statistics for 3 full epochs, i.e. ``n_warmup_steps = 3 * len(dataloader)``.
3737
3838 .. warning::
3939 The state tensor :math:`S \in \mathbb{R}^{m \times n}` accumulates squared gradients
@@ -51,7 +51,7 @@ class ExcessMTLWeighting(_MatrixWeighting, Stateful, _NonDifferentiable):
5151 def __init__ (
5252 self ,
5353 robust_step_size : float = 1.0 ,
54- n_warmup_steps : int = 1 ,
54+ n_warmup_steps : int = 0 ,
5555 ) -> None :
5656 super ().__init__ ()
5757 self .robust_step_size = robust_step_size
@@ -87,49 +87,38 @@ def n_warmup_steps(self, value: int) -> None:
8787 )
8888 self ._n_warmup_steps = value
8989
90- def reset (self ) -> None :
91- """Clears all state so the next forward starts from uniform weights and re-enters
92- warmup."""
93-
94- self ._weights = None
95- self ._sq_grad_sum = None
96- self ._initial_w = None
97- self ._warmup_w_sum = None
98- self ._n_steps = 0
99- self ._state_key = None
100-
10190 def forward (self , matrix : Matrix , / ) -> Tensor :
10291 self ._ensure_state (matrix )
10392
10493 sq_matrix = matrix .detach () ** 2
10594
10695 # Accumulate squared gradients for AdaGrad-style diagonal Hessian (Equation 7)
10796 sq_grad_sum = cast (Tensor , self ._sq_grad_sum )
108- sq_grad_sum += sq_matrix
97+ sq_grad_sum . add_ ( sq_matrix )
10998
11099 # Excess risk proxy: Ê_i ≈ g_i^T H_i^{-1} g_i (Equation 6)
111100 h = torch .sqrt (sq_grad_sum + 1e-7 )
112101 w = (sq_matrix / h ).sum (dim = 1 ) # shape [m]
113102
114- n_steps = self ._n_steps
115- self ._n_steps += 1
116-
117103 # Warmup: collect excess risk stats but return uniform weights
118- if n_steps < self ._n_warmup_steps :
104+ if self . _n_steps < self ._n_warmup_steps :
119105 cast (Tensor , self ._warmup_w_sum ).add_ (w )
106+ self ._n_steps += 1
120107 return cast (Tensor , self ._weights )
121108
109+ self ._n_steps += 1
110+
122111 # Set baseline on the first non-warmup call
123112 if self ._initial_w is None :
124113 if self ._n_warmup_steps > 0 :
125114 # Average excess risk observed during warmup (Appendix C.1)
126115 self ._initial_w = cast (Tensor , self ._warmup_w_sum ) / self ._n_warmup_steps
127- w = w / (self ._initial_w + 1e-7 )
116+ w = w / (self ._initial_w + 1e-7 ) # Scale processing (Section 3.2)
128117 else :
129118 # Official impl behavior: first call's excess is the baseline; use w raw
130119 self ._initial_w = w
131120 else :
132- w = w / (self ._initial_w + 1e-7 )
121+ w = w / (self ._initial_w + 1e-7 ) # Scale processing (Section 3.2)
133122
134123 # Exponentiated gradient weight update (Equation 9)
135124 weights = cast (Tensor , self ._weights )
@@ -138,6 +127,17 @@ def forward(self, matrix: Matrix, /) -> Tensor:
138127 self ._weights = weights
139128 return weights
140129
130+ def reset (self ) -> None :
131+ """Clears all state so the next forward starts from uniform weights and re-enters
132+ warmup."""
133+
134+ self ._weights = None
135+ self ._sq_grad_sum = None
136+ self ._initial_w = None
137+ self ._warmup_w_sum = None
138+ self ._n_steps = 0
139+ self ._state_key = None
140+
141141 def _ensure_state (self , matrix : Matrix ) -> None :
142142 key = (matrix .shape [0 ], matrix .shape [1 ], matrix .dtype , matrix .device )
143143 if self ._state_key == key and self ._sq_grad_sum is not None :
@@ -160,6 +160,7 @@ def __repr__(self) -> str:
160160
161161class ExcessMTL (WeightedAggregator , Stateful , _NonDifferentiable ):
162162 r"""
163+ :class:`~torchjd.Stateful`
163164 :class:`~torchjd.aggregation.WeightedAggregator` from `Robust Multi-Task Learning with Excess
164165 Risks <https://proceedings.mlr.press/v235/he24n.html>`_ (ICML 2024).
165166
@@ -170,16 +171,19 @@ class ExcessMTL(WeightedAggregator, Stateful, _NonDifferentiable):
170171 :param robust_step_size: Step size :math:`\eta_\alpha` for the exponentiated weight update.
171172 Must be positive.
172173 :param n_warmup_steps: Number of forward calls during which weights stay uniform
173- (:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. When ``0``, the first
174- call's excess risk is used as the baseline immediately. Defaults to ``1``.
174+ (:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. The baseline excess
175+ risk is then set to the average excess risk observed during warmup. When ``0`` (default),
176+ the first call's excess risk is used immediately as the baseline, matching the behavior of
177+ the official implementation and LibMTL. The paper (Appendix C.1) recommends collecting
178+ statistics for 3 full epochs, i.e. ``n_warmup_steps = 3 * len(dataloader)``.
175179 """
176180
177181 weighting : ExcessMTLWeighting
178182
179183 def __init__ (
180184 self ,
181185 robust_step_size : float = 1.0 ,
182- n_warmup_steps : int = 1 ,
186+ n_warmup_steps : int = 0 ,
183187 ) -> None :
184188 super ().__init__ (ExcessMTLWeighting (robust_step_size , n_warmup_steps ))
185189
0 commit comments