Skip to content

Commit f227121

Browse files
test(aggregation): Add StrongStationarityProperty (#257)
* Change matrix generation in _inputs.py to be able to generate strongly stationary matrices, strictly weakly stationary matrices and non weakly-stationary matrices * Change typical matrices to contain the newly added types of matrices * Rename a few things in _inputs.py to make lines shorter * Increase tolerance of some existing tests (LinearUnderScalingProperty and test_equivalence_mean of CAGrad) to make them pass on the new typical_matrices * Add StrongStationarityProperty in _property_testers.py * Make Constant, DualProj, Mean, Random, Sum and UPGrad extend StrongStationarityProperty. Note that MGDA should not and does not pass these tests. CAGrad should pass these tests but does not (probably due to some implementation issue of CAGrad). AlignedMTL passes the tests but it's hard to tell if it should in theory. For a few other aggregators, we did not try the property tester, and we do not know if they have the property in theory. EDIT: this might have changed after changing the dimensions of the matrices that we test.
1 parent be7bdad commit f227121

File tree

9 files changed

+154
-58
lines changed

9 files changed

+154
-58
lines changed

tests/unit/aggregation/_inputs.py

Lines changed: 92 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,65 +3,107 @@
33
from torch.nn.functional import normalize
44

55

6-
def _generate_matrix(m: int, n: int, rank: int) -> Tensor:
7-
"""Generates a random matrix A of shape [m, n] with provided rank."""
6+
def _sample_matrix(m: int, n: int, rank: int) -> Tensor:
7+
"""Samples a random matrix A of shape [m, n] with provided rank."""
88

9-
U = _generate_orthonormal_matrix(m)
10-
Vt = _generate_orthonormal_matrix(n)
9+
U = _sample_orthonormal_matrix(m)
10+
Vt = _sample_orthonormal_matrix(n)
1111
S = torch.diag(torch.abs(torch.randn([rank])))
1212
A = U[:, :rank] @ S @ Vt[:rank, :]
1313
return A
1414

1515

16-
def _generate_strong_stationary_matrix(m: int, n: int) -> Tensor:
16+
def _sample_strong_matrix(m: int, n: int, rank: int) -> Tensor:
1717
"""
18-
Generates a random matrix A of shape [m, n] with rank min(n, m - 1), such that there exists a
19-
vector 0<v with v^T A = 0.
18+
Samples a random strongly stationary matrix A of shape [m, n] with provided rank.
19+
20+
Definition: A matrix A is said to be strongly stationary if there exists a vector 0 < v such
21+
that v^T A = 0.
22+
23+
This is done by sampling a positive v, and by then sampling a matrix orthogonal to v.
2024
"""
2125

26+
assert 1 < m
27+
assert 0 < rank <= min(m - 1, n)
28+
2229
v = torch.abs(torch.randn([m]))
23-
return _generate_matrix_orthogonal_to_vector(v, n)
30+
U1 = normalize(v, dim=0).unsqueeze(1)
31+
U2 = _sample_semi_orthonormal_complement(U1)
32+
Vt = _sample_orthonormal_matrix(n)
33+
S = torch.diag(torch.abs(torch.randn([rank])))
34+
A = U2[:, :rank] @ S @ Vt[:rank, :]
35+
return A
2436

2537

26-
def _generate_weak_stationary_matrix(m: int, n: int) -> Tensor:
38+
def _sample_strictly_weak_matrix(m: int, n: int, rank: int) -> Tensor:
2739
"""
28-
Generates a random matrix A of shape [m, n] with rank min(n, m - 1), such that there exists a
29-
vector 0<=v with one coordinate equal to 0 and such that v^T A = 0.
40+
Samples a random strictly weakly stationary matrix A of shape [m, n] with provided rank.
41+
42+
Definition: A matrix A is said to be weakly stationary if there exists a vector 0 <= v, v != 0,
43+
such that v^T A = 0.
3044
31-
Note that if multiple coordinates of v were equal to 0, the generated matrix would still be weak
32-
stationary, but here we only set one of them to 0 for simplicity.
45+
Definition: A matrix A is said to be strictly weakly stationary if it is weakly stationary and
46+
not strongly stationary, i.e. if there exists a vector 0 <= v, v != 0, such that v^T A = 0 and
47+
there exists no vector 0 < w with w^T A = 0.
48+
49+
This is done by sampling two unit-norm vectors v, v', whose sum u is a positive vector. These
50+
two vectors are also non-negative and non-zero, and are furthermore orthogonal. Then, a matrix
51+
A, orthogonal to v, is sampled. By its orthogonality to v, A is weakly stationary. Moreover,
52+
since v' is a non-negative left-singular vector of A with positive singular value s, any 0 < w
53+
satisfies w^T A != 0. Otherwise, we would have 0 = w^T A A^T v' = s w^T v' > 0, which is a
54+
contradiction. A is thus also not strongly stationary.
3355
"""
3456

35-
v = torch.abs(torch.randn([m]))
36-
i = torch.randint(0, m, [])
37-
v[i] = 0.0
38-
return _generate_matrix_orthogonal_to_vector(v, n)
57+
assert 1 < m
58+
assert 0 < rank <= min(m - 1, n)
59+
60+
u = torch.abs(torch.randn([m]))
61+
split_index = torch.randint(1, m, []).item()
62+
shuffled_range = torch.randperm(m)
63+
v = torch.zeros(m)
64+
v[shuffled_range[:split_index]] = normalize(u[shuffled_range[:split_index]], dim=0)
65+
v_prime = torch.zeros(m)
66+
v_prime[shuffled_range[split_index:]] = normalize(u[shuffled_range[split_index:]], dim=0)
67+
U1 = torch.stack([v, v_prime]).T
68+
U2 = _sample_semi_orthonormal_complement(U1)
69+
U = torch.hstack([U1, U2])
70+
Vt = _sample_orthonormal_matrix(n)
71+
S = torch.diag(torch.abs(torch.randn([rank])))
72+
A = U[:, 1 : rank + 1] @ S @ Vt[:rank, :]
73+
return A
3974

4075

41-
def _generate_matrix_orthogonal_to_vector(v: Tensor, n: int) -> Tensor:
76+
def _sample_non_weak_matrix(m: int, n: int, rank: int) -> Tensor:
4277
"""
43-
Generates a random matrix A of shape [len(v), n] with rank min(n, len(v) - 1) such that
44-
v^T A = 0.
78+
Samples a random non weakly-stationary matrix A of shape [m, n] with provided rank.
79+
80+
This is done by sampling a positive u, and by then sampling a matrix A that has u as one of its
81+
left-singular vectors, with positive singular value s. Any 0 <= v, v != 0, satisfies v^T A != 0.
82+
Otherwise, we would have 0 = v^T A A^T u = s v^T u > 0, which is a contradiction. A is thus not
83+
weakly stationary.
4584
"""
4685

47-
rank = min(n, len(v) - 1)
48-
Q = normalize(v, dim=0).unsqueeze(1)
49-
U = _generate_semi_orthonormal_complement(Q)
50-
Vt = _generate_orthonormal_matrix(n)
86+
assert 0 < rank <= min(m, n)
87+
88+
u = torch.abs(torch.randn([m]))
89+
U1 = normalize(u, dim=0).unsqueeze(1)
90+
U2 = _sample_semi_orthonormal_complement(U1)
91+
U = torch.hstack([U1, U2])
92+
Vt = _sample_orthonormal_matrix(n)
5193
S = torch.diag(torch.abs(torch.randn([rank])))
5294
A = U[:, :rank] @ S @ Vt[:rank, :]
5395
return A
5496

5597

56-
def _generate_orthonormal_matrix(dim: int) -> Tensor:
57-
"""Uniformly generates a random orthonormal matrix of shape [dim, dim]."""
98+
def _sample_orthonormal_matrix(dim: int) -> Tensor:
99+
"""Uniformly samples a random orthonormal matrix of shape [dim, dim]."""
58100

59-
return _generate_semi_orthonormal_complement(torch.zeros([dim, 0]))
101+
return _sample_semi_orthonormal_complement(torch.zeros([dim, 0]))
60102

61103

62-
def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
104+
def _sample_semi_orthonormal_complement(Q: Tensor) -> Tensor:
63105
"""
64-
Uniformly generates a random semi-orthonormal matrix Q' (i.e. Q'^T Q' = I) of shape [m, m-k]
106+
Uniformly samples a random semi-orthonormal matrix Q' (i.e. Q'^T Q' = I) of shape [m, m-k]
65107
orthogonal to Q, i.e. such that the concatenation [Q, Q'] is an orthonormal matrix.
66108
67109
:param Q: A semi-orthonormal matrix (i.e. Q^T Q = I) of shape [m, k], with k <= m.
@@ -77,7 +119,7 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
77119
return Q_prime
78120

79121

80-
_matrix_dimension_triples = [
122+
_normal_dims = [
81123
(1, 1, 1),
82124
(4, 3, 1),
83125
(4, 3, 2),
@@ -86,32 +128,35 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
86128
(9, 11, 9),
87129
]
88130

89-
_zero_matrices_shapes = [
90-
(1, 1),
91-
(4, 3),
92-
(9, 11),
131+
_zero_dims = [
132+
(1, 1, 0),
133+
(4, 3, 0),
134+
(9, 11, 0),
93135
]
94136

95-
_stationary_matrices_shapes = [
96-
(5, 3),
97-
(9, 11),
137+
_stationarity_dims = [
138+
(20, 10, 10),
139+
(20, 10, 5),
140+
(20, 10, 1),
141+
(20, 100, 1),
142+
(20, 100, 19),
98143
]
99144

100145
_scales = [0.0, 1e-10, 1e3, 1e5, 1e10, 1e15]
101146

102-
# Fix seed to fix randomness of matrix generation
147+
# Fix seed to fix randomness of matrix sampling
103148
torch.manual_seed(0)
104149

105-
matrices = [_generate_matrix(m, n, rank) for m, n, rank in _matrix_dimension_triples]
150+
matrices = [_sample_matrix(m, n, r) for m, n, r in _normal_dims]
151+
zero_matrices = [torch.zeros([m, n]) for m, n, _ in _zero_dims]
152+
strong_matrices = [_sample_strong_matrix(m, n, r) for m, n, r in _stationarity_dims]
153+
strictly_weak_matrices = [_sample_strictly_weak_matrix(m, n, r) for m, n, r in _stationarity_dims]
154+
non_weak_matrices = [_sample_non_weak_matrix(m, n, r) for m, n, r in _stationarity_dims]
155+
106156
scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices]
107-
zero_matrices = [torch.zeros([m, n]) for m, n in _zero_matrices_shapes]
108-
strong_stationary_matrices = [
109-
_generate_strong_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes
110-
]
111-
weak_stationary_matrices = [
112-
_generate_weak_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes
113-
]
114-
typical_matrices = zero_matrices + matrices + weak_stationary_matrices + strong_stationary_matrices
157+
158+
non_strong_matrices = strictly_weak_matrices + non_weak_matrices
159+
typical_matrices = zero_matrices + matrices + strong_matrices + non_strong_matrices
115160

116161
scaled_matrices_2_plus_rows = [matrix for matrix in scaled_matrices if matrix.shape[0] >= 2]
117162
typical_matrices_2_plus_rows = [matrix for matrix in typical_matrices if matrix.shape[0] >= 2]

tests/unit/aggregation/_property_testers.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from torchjd.aggregation import Aggregator
77

8-
from ._inputs import scaled_matrices, typical_matrices
8+
from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
99

1010

1111
class ExpectedStructureProperty:
@@ -101,4 +101,25 @@ def _assert_linear_under_scaling_property(
101101
x = aggregator(torch.diag(alpha * c1 + beta * c2) @ matrix)
102102
expected = alpha * x1 + beta * x2
103103

104-
assert_close(x, expected, atol=8e-03, rtol=0)
104+
assert_close(x, expected, atol=1e-02, rtol=0)
105+
106+
107+
class StrongStationarityProperty:
108+
"""
109+
This class tests empirically that a given `Aggregator` is strongly stationary.
110+
111+
An aggregator `A` is strongly stationary if for any matrix `J` with `A(J)=0`, `J` is strongly
112+
stationary, i.e., there exists `0<w` such that `J^T w=0`. In this class, we test the
113+
contraposition: whenever `J` is not strongly stationary, we must have `A(J) != 0`.
114+
"""
115+
116+
@classmethod
117+
@mark.parametrize("matrix", non_strong_matrices)
118+
def test_stationarity_property(cls, aggregator: Aggregator, matrix: Tensor):
119+
cls._assert_stationarity_property(aggregator, matrix)
120+
121+
@staticmethod
122+
def _assert_stationarity_property(aggregator: Aggregator, matrix: Tensor) -> None:
123+
vector = aggregator(matrix)
124+
norm = vector.norm().item()
125+
assert norm > 1e-03

tests/unit/aggregation/test_cagrad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_equivalence_mean(matrix: Tensor):
3333
result = ca_grad(matrix)
3434
expected = mean(matrix)
3535

36-
assert_close(result, expected)
36+
assert_close(result, expected, atol=2e-1, rtol=0)
3737

3838

3939
@mark.parametrize(

tests/unit/aggregation/test_constant.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@
77

88
from torchjd.aggregation import Constant
99

10-
from ._inputs import scaled_matrices, typical_matrices
11-
from ._property_testers import ExpectedStructureProperty, LinearUnderScalingProperty
10+
from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
11+
from ._property_testers import (
12+
ExpectedStructureProperty,
13+
LinearUnderScalingProperty,
14+
StrongStationarityProperty,
15+
)
1216

1317
# The weights must be a vector of length equal to the number of rows in the matrix that it will be
1418
# applied to. Thus, each `Constant` instance is specific to matrices of a given number of rows. To
@@ -28,8 +32,13 @@ def _make_aggregator(matrix: Tensor) -> Constant:
2832
_matrices_2 = typical_matrices
2933
_aggregators_2 = [_make_aggregator(matrix) for matrix in _matrices_2]
3034

35+
_matrices_3 = non_strong_matrices
36+
_aggregators_3 = [_make_aggregator(matrix) for matrix in _matrices_3]
37+
3138

32-
class TestConstant(ExpectedStructureProperty, LinearUnderScalingProperty):
39+
class TestConstant(
40+
ExpectedStructureProperty, LinearUnderScalingProperty, StrongStationarityProperty
41+
):
3342
# Override the parametrization of `test_expected_structure_property` to make the test use the
3443
# right aggregator with each matrix.
3544

@@ -43,6 +52,11 @@ def test_expected_structure_property(cls, aggregator: Constant, matrix: Tensor):
4352
def test_linear_under_scaling_property(cls, aggregator: Constant, matrix: Tensor):
4453
cls._assert_linear_under_scaling_property(aggregator, matrix)
4554

55+
@classmethod
56+
@mark.parametrize(["aggregator", "matrix"], zip(_aggregators_3, _matrices_3))
57+
def test_stationarity_property(cls, aggregator: Constant, matrix: Tensor):
58+
cls._assert_stationarity_property(aggregator, matrix)
59+
4660

4761
@mark.parametrize(
4862
["weights_shape", "expectation"],

tests/unit/aggregation/test_dualproj.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@
77
ExpectedStructureProperty,
88
NonConflictingProperty,
99
PermutationInvarianceProperty,
10+
StrongStationarityProperty,
1011
)
1112

1213

1314
@mark.parametrize("aggregator", [DualProj()])
1415
class TestDualProj(
15-
ExpectedStructureProperty, NonConflictingProperty, PermutationInvarianceProperty
16+
ExpectedStructureProperty,
17+
NonConflictingProperty,
18+
PermutationInvarianceProperty,
19+
StrongStationarityProperty,
1620
):
1721
pass
1822

tests/unit/aggregation/test_mean.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
ExpectedStructureProperty,
77
LinearUnderScalingProperty,
88
PermutationInvarianceProperty,
9+
StrongStationarityProperty,
910
)
1011

1112

1213
@mark.parametrize("aggregator", [Mean()])
1314
class TestMean(
14-
ExpectedStructureProperty, PermutationInvarianceProperty, LinearUnderScalingProperty
15+
ExpectedStructureProperty,
16+
PermutationInvarianceProperty,
17+
LinearUnderScalingProperty,
18+
StrongStationarityProperty,
1519
):
1620
pass
1721

tests/unit/aggregation/test_random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from torchjd.aggregation import Random
44

5-
from ._property_testers import ExpectedStructureProperty
5+
from ._property_testers import ExpectedStructureProperty, StrongStationarityProperty
66

77

88
@mark.parametrize("aggregator", [Random()])
9-
class TestRandom(ExpectedStructureProperty):
9+
class TestRandom(ExpectedStructureProperty, StrongStationarityProperty):
1010
pass
1111

1212

tests/unit/aggregation/test_sum.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@
66
ExpectedStructureProperty,
77
LinearUnderScalingProperty,
88
PermutationInvarianceProperty,
9+
StrongStationarityProperty,
910
)
1011

1112

1213
@mark.parametrize("aggregator", [Sum()])
13-
class TestSum(ExpectedStructureProperty, PermutationInvarianceProperty, LinearUnderScalingProperty):
14+
class TestSum(
15+
ExpectedStructureProperty,
16+
PermutationInvarianceProperty,
17+
LinearUnderScalingProperty,
18+
StrongStationarityProperty,
19+
):
1420
pass
1521

1622

tests/unit/aggregation/test_upgrad.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
LinearUnderScalingProperty,
99
NonConflictingProperty,
1010
PermutationInvarianceProperty,
11+
StrongStationarityProperty,
1112
)
1213

1314

@@ -17,6 +18,7 @@ class TestUPGrad(
1718
NonConflictingProperty,
1819
PermutationInvarianceProperty,
1920
LinearUnderScalingProperty,
21+
StrongStationarityProperty,
2022
):
2123
pass
2224

0 commit comments

Comments
 (0)