We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 1ffcfa2 commit 0b31273Copy full SHA for 0b31273
tests/unit/aggregation/test_mgda.py
@@ -226,6 +226,16 @@ def test_loss_normalization_validates_size():
226
weighting(gramian)
227
228
229
+def test_loss_plus_normalization_validates_size():
230
+ """Test that loss+ normalization validates losses size matches gramian."""
231
+ weighting = MGDAWeighting(norm_type="loss+")
232
+ weighting.set_losses(torch.tensor([1.0, 2.0, 3.0])) # 3 losses
233
+ gramian = torch.tensor([[4.0, 2.0], [2.0, 9.0]]) # 2x2 gramian
234
+
235
+ with raises(ValueError, match=r"Number of losses .* must match"):
236
+ weighting(gramian)
237
238
239
def test_loss_normalization_weights_sum_to_one():
240
"""Test that loss normalization produces weights that sum to 1."""
241
weighting = MGDAWeighting(norm_type="loss")
0 commit comments