Skip to content

Commit ceaaf12

Browse files
committed
test: Add ExcessMTL aggregator coverage and fix redundant casts
1 parent 20b7f08 commit ceaaf12

2 files changed

Lines changed: 34 additions & 3 deletions

File tree

src/torchjd/aggregation/_excess_mtl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,12 @@ def forward(self, matrix: Matrix, /) -> Tensor:
124124
if self._n_warmup_steps > 0:
125125
# Average excess risk observed during warmup (Appendix C.1)
126126
self._initial_w = cast(Tensor, self._warmup_w_sum) / self._n_warmup_steps
127-
w = w / (cast(Tensor, self._initial_w) + 1e-7)
127+
w = w / (self._initial_w + 1e-7)
128128
else:
129129
# Official impl behavior: first call's excess is the baseline; use w raw
130130
self._initial_w = w
131131
else:
132-
w = w / (cast(Tensor, self._initial_w) + 1e-7)
132+
w = w / (self._initial_w + 1e-7)
133133

134134
# Exponentiated gradient weight update (Equation 9)
135135
weights = cast(Tensor, self._weights)

tests/unit/aggregation/test_excess_mtl.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch.testing import assert_close
44
from utils.tensors import randn_, tensor_
55

6-
from torchjd.aggregation._excess_mtl import ExcessMTLWeighting
6+
from torchjd.aggregation._excess_mtl import ExcessMTL, ExcessMTLWeighting
77

88

99
def test_representations() -> None:
@@ -190,3 +190,34 @@ def test_non_differentiable() -> None:
190190
W = ExcessMTLWeighting()
191191
weights = W(J)
192192
assert not weights.requires_grad
193+
194+
195+
# ExcessMTL (aggregator wrapper) tests
196+
197+
198+
def test_excess_mtl_representations() -> None:
199+
agg = ExcessMTL(robust_step_size=2.0, n_warmup_steps=3)
200+
assert repr(agg) == "ExcessMTL(robust_step_size=2.0, n_warmup_steps=3)"
201+
202+
203+
def test_excess_mtl_properties_delegate() -> None:
204+
agg = ExcessMTL(robust_step_size=1.0, n_warmup_steps=0)
205+
assert agg.robust_step_size == 1.0
206+
assert agg.n_warmup_steps == 0
207+
208+
agg.robust_step_size = 0.5
209+
assert agg.robust_step_size == 0.5
210+
assert agg.weighting.robust_step_size == 0.5
211+
212+
agg.n_warmup_steps = 5
213+
assert agg.n_warmup_steps == 5
214+
assert agg.weighting.n_warmup_steps == 5
215+
216+
217+
def test_excess_mtl_reset_delegates() -> None:
218+
J = randn_((3, 8))
219+
agg = ExcessMTL(n_warmup_steps=0)
220+
first = agg(J)
221+
agg(J)
222+
agg.reset()
223+
assert_close(first, agg(J))

0 commit comments

Comments
 (0)