-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathtest_mgda.py
More file actions
72 lines (54 loc) · 2.08 KB
/
test_mgda.py
File metadata and controls
72 lines (54 loc) · 2.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from pytest import mark
from torch import Tensor
from torch.testing import assert_close
from tests.utils.tensors import ones_, randn_
from torchjd.aggregation import MGDA
from torchjd.aggregation._mgda import MGDAWeighting
from torchjd.aggregation._utils.gramian import compute_gramian
from ._asserts import (
assert_expected_structure,
assert_non_conflicting,
assert_permutation_invariant,
)
from ._inputs import scaled_matrices, typical_matrices
# Parametrization data: each (aggregator, matrix) pair gets its own freshly
# constructed MGDA instance, so no aggregator object is shared between cases.
scaled_pairs = [(MGDA(), m) for m in scaled_matrices]
typical_pairs = [(MGDA(), m) for m in typical_matrices]
@mark.parametrize("aggregator,matrix", scaled_pairs + typical_pairs)
def test_expected_structure(aggregator: MGDA, matrix: Tensor):
    """Run the shared expected-structure check on MGDA, over scaled and typical matrices."""
    assert_expected_structure(aggregator, matrix)
@mark.parametrize("aggregator,matrix", typical_pairs)
def test_non_conflicting(aggregator: MGDA, matrix: Tensor):
    """Run the shared non-conflicting check on MGDA, over typical matrices."""
    assert_non_conflicting(aggregator, matrix)
@mark.parametrize("aggregator,matrix", typical_pairs)
def test_permutation_invariant(aggregator: MGDA, matrix: Tensor):
    """Run the shared permutation-invariance check on MGDA, over typical matrices."""
    assert_permutation_invariant(aggregator, matrix)
@mark.parametrize(
    "shape",
    [
        (5, 7),
        (9, 37),
        (2, 14),
        (32, 114),
        (50, 100),
    ],
)
def test_mgda_satisfies_kkt_conditions(shape: tuple[int, int]):
    """
    Check that the weights returned by ``MGDAWeighting`` satisfy a KKT system.

    The checks below are the KKT conditions of the quadratic program
    ``min_w 1/2 w^T G w  s.t.  w >= 0,  sum(w) = 1``, where ``G`` is the Gramian
    of a random matrix. With Lagrangian
    ``1/2 w^T G w + lambda * (sum(w) - 1) - mu^T w``:

    - stationarity gives ``mu = G w + lambda``;
    - complementary slackness (``mu^T w = 0``) together with ``sum(w) = 1``
      gives ``lambda = -w^T G w``;
    - primal feasibility requires ``w >= 0`` and ``sum(w) = 1``;
    - dual feasibility requires ``mu >= 0``.

    Stationarity and complementary slackness are used to *construct* lambda
    and mu, so only the feasibility conditions are asserted.
    """
    matrix = randn_(shape)
    gramian = compute_gramian(matrix)
    weighting = MGDAWeighting(epsilon=1e-05, max_iters=1000)
    weights = weighting(gramian)

    # G w: gradient of the objective 1/2 w^T G w at the returned weights.
    output_direction = gramian @ weights

    # Complementary slackness (mu^T w = 0) combined with sum(w) = 1 pins the
    # equality-constraint multiplier to lambda = -w^T G w.
    lamb = -weights @ output_direction

    # Stationarity (G w + lambda * 1 - mu = 0) then determines the
    # inequality-constraint multipliers mu.
    mu = output_direction + lamb

    # Primal feasibility, part 1: w >= 0. The norm of the non-negative entries
    # equals the full norm only if the negative entries are (numerically) zero.
    positive_weights = weights[weights >= 0]
    assert_close(positive_weights.norm(), weights.norm())
    # Primal feasibility, part 2: sum(w) = 1.
    weights_sum = weights.sum()
    assert_close(weights_sum, ones_([]))

    # Dual feasibility: mu >= 0, checked with the same norm trick. The absolute
    # tolerance is loose (1e-2) — presumably because the solver is iterative
    # (epsilon / max_iters) and only approximately optimal; TODO confirm.
    positive_mu = mu[mu >= 0]
    assert_close(positive_mu.norm(), mu.norm(), atol=1e-02, rtol=0.0)
def test_representations():
    """Check ``repr`` and ``str`` of an MGDA aggregator built with explicit hyper-parameters."""
    aggregator = MGDA(epsilon=0.001, max_iters=100)
    expected_repr = "MGDA(epsilon=0.001, max_iters=100)"
    assert repr(aggregator) == expected_repr
    assert str(aggregator) == "MGDA"