Skip to content

Commit 0b31273

Browse files
committed
test: Add coverage for loss+ size validation
Add test_loss_plus_normalization_validates_size to cover line 267 (the ValueError in _normalize_gramian_loss_plus for size mismatch).
1 parent 1ffcfa2 commit 0b31273

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

tests/unit/aggregation/test_mgda.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,16 @@ def test_loss_normalization_validates_size():
226226
weighting(gramian)
227227

228228

229+
def test_loss_plus_normalization_validates_size():
230+
"""Test that loss+ normalization validates losses size matches gramian."""
231+
weighting = MGDAWeighting(norm_type="loss+")
232+
weighting.set_losses(torch.tensor([1.0, 2.0, 3.0])) # 3 losses
233+
gramian = torch.tensor([[4.0, 2.0], [2.0, 9.0]]) # 2x2 gramian
234+
235+
with raises(ValueError, match=r"Number of losses .* must match"):
236+
weighting(gramian)
237+
238+
229239
def test_loss_normalization_weights_sum_to_one():
230240
"""Test that loss normalization produces weights that sum to 1."""
231241
weighting = MGDAWeighting(norm_type="loss")

0 commit comments

Comments
 (0)