Skip to content

Commit 4dcf732

Browse files
authored
test(aggregation): Add NashMTL tests (#318)
* Add nash_mtl_matrices in _inputs.py * Add TestNashMTL * Add test_nash_mtl_reset
1 parent e42d677 commit 4dcf732

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

tests/unit/aggregation/_inputs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,14 @@ def _sample_semi_orthonormal_complement(Q: Tensor) -> Tensor:
160160

161161
scaled_matrices_2_plus_rows = [matrix for matrix in scaled_matrices if matrix.shape[0] >= 2]
162162
typical_matrices_2_plus_rows = [matrix for matrix in typical_matrices if matrix.shape[0] >= 2]
163+
164+
# It seems that NashMTL does not work for matrices with 1 row, so we make different matrices for it.
165+
_nashmtl_dims = [
166+
(3, 1, 1),
167+
(4, 3, 1),
168+
(4, 3, 2),
169+
(4, 3, 3),
170+
(9, 11, 5),
171+
(9, 11, 9),
172+
]
173+
nash_mtl_matrices = [_sample_matrix(m, n, r) for m, n, r in _nashmtl_dims]

tests/unit/aggregation/test_nash_mtl.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,56 @@
1+
import torch
2+
from pytest import mark
3+
from torch import Tensor
4+
from torch.testing import assert_close
5+
16
from torchjd.aggregation import NashMTL
27

8+
from ._inputs import nash_mtl_matrices
9+
from ._property_testers import ExpectedStructureProperty
10+
11+
12+
def _make_aggregator(matrix: Tensor) -> NashMTL:
13+
return NashMTL(n_tasks=matrix.shape[0])
14+
15+
16+
_aggregators = [_make_aggregator(matrix) for matrix in nash_mtl_matrices]
17+
18+
19+
@mark.filterwarnings(
20+
"ignore:Solution may be inaccurate.",
21+
"ignore:You are solving a parameterized problem that is not DPP.",
22+
)
23+
class TestNashMTL(ExpectedStructureProperty):
24+
# Override the parametrization of `test_expected_structure_property` to make the test use the
25+
# right aggregator with each matrix.
26+
27+
# Note that as opposed to most aggregators, the ExpectedStructureProperty is only tested with
28+
# non-scaled matrices, and with matrices of > 1 row. Otherwise, NashMTL fails.
29+
@classmethod
30+
@mark.parametrize(["aggregator", "matrix"], zip(_aggregators, nash_mtl_matrices))
31+
def test_expected_structure_property(cls, aggregator: NashMTL, matrix: Tensor):
32+
cls._assert_expected_structure_property(aggregator, matrix)
33+
34+
35+
@mark.filterwarnings("ignore: You are solving a parameterized problem that is not DPP.")
36+
def test_nash_mtl_reset():
37+
"""
38+
Tests that the reset method of NashMTL correctly resets its internal state, by verifying that
39+
the result is the same after reset as it is right after instantiation.
40+
41+
To ensure that the aggregations are not all the same, we create different matrices to aggregate.
42+
"""
43+
44+
matrices = [torch.randn(3, 5) for _ in range(4)]
45+
aggregator = NashMTL(n_tasks=3, update_weights_every=3)
46+
expecteds = [aggregator(matrix) for matrix in matrices]
47+
48+
aggregator.reset()
49+
results = [aggregator(matrix) for matrix in matrices]
50+
51+
for result, expected in zip(results, expecteds):
52+
assert_close(result, expected)
53+
354

455
def test_representations():
556
A = NashMTL(n_tasks=2)

0 commit comments

Comments
 (0)