diff --git a/tests/unit/aggregation/_inputs.py b/tests/unit/aggregation/_inputs.py
index beb9f1ff0..e536c3c38 100644
--- a/tests/unit/aggregation/_inputs.py
+++ b/tests/unit/aggregation/_inputs.py
@@ -3,65 +3,107 @@
 from torch.nn.functional import normalize
 
 
-def _generate_matrix(m: int, n: int, rank: int) -> Tensor:
-    """Generates a random matrix A of shape [m, n] with provided rank."""
+def _sample_matrix(m: int, n: int, rank: int) -> Tensor:
+    """Samples a random matrix A of shape [m, n] with provided rank."""
 
-    U = _generate_orthonormal_matrix(m)
-    Vt = _generate_orthonormal_matrix(n)
+    U = _sample_orthonormal_matrix(m)
+    Vt = _sample_orthonormal_matrix(n)
     S = torch.diag(torch.abs(torch.randn([rank])))
     A = U[:, :rank] @ S @ Vt[:rank, :]
     return A
 
 
-def _generate_strong_stationary_matrix(m: int, n: int) -> Tensor:
+def _sample_strong_matrix(m: int, n: int, rank: int) -> Tensor:
     """
-    Generates a random matrix A of shape [m, n] with rank min(n, m - 1), such that there exists a
-    vector 0<v such that v^T A = 0.
-    """
-
-    v = torch.abs(torch.randn([m]))
-    return _generate_matrix_orthogonal_to_vector(v, n)
+    Samples a random strongly stationary matrix A of shape [m, n] with provided rank.
+
+    Definition: A matrix A is said to be strongly stationary if there exists a vector 0 < v such
+    that v^T A = 0.
+
+    This is done by sampling a positive vector v and a matrix A orthogonal to v.
+    """
+
+    assert 1 < m
+    assert 0 < rank <= min(m - 1, n)
+
+    v = torch.abs(torch.randn([m]))
+    U1 = normalize(v, dim=0).unsqueeze(1)
+    U2 = _sample_semi_orthonormal_complement(U1)
+    Vt = _sample_orthonormal_matrix(n)
+    S = torch.diag(torch.abs(torch.randn([rank])))
+    A = U2[:, :rank] @ S @ Vt[:rank, :]
+    return A
 
 
-def _generate_weak_stationary_matrix(m: int, n: int) -> Tensor:
+def _sample_strictly_weak_matrix(m: int, n: int, rank: int) -> Tensor:
     """
-    Generates a random matrix A of shape [m, n] with rank min(n, m - 1), such that there exists a
-    vector 0<=v with one coordinate equal to 0 and such that v^T A = 0.
+    Samples a random strictly weakly stationary matrix A of shape [m, n] with provided rank.
+
+    Definition: A matrix A is said to be weakly stationary if there exists a vector 0 <= v, v != 0,
+    such that v^T A = 0.
 
-    Note that if multiple coordinates of v were equal to 0, the generated matrix would still be weak
-    stationary, but here we only set one of them to 0 for simplicity.
+    Definition: A matrix A is said to be strictly weakly stationary if it is weakly stationary and
+    not strongly stationary, i.e. if there exists a vector 0 <= v, v != 0, such that v^T A = 0 and
+    there exists no vector 0 < w with w^T A = 0.
+
+    This is done by sampling two unit-norm vectors v, v', whose sum u is a positive vector.
These + two vectors are also non-negative and non-zero, and are furthermore orthogonal. Then, a matrix + A, orthogonal to v, is sampled. By its orthogonality to v, A is weakly stationary. Moreover, + since v' is a non-negative left-singular vector of A with positive singular value s, any 0 < w + satisfies w^T A != 0. Otherwise, we would have 0 = w^T A A^T v' = s w^T v' > 0, which is a + contradiction. A is thus also not strongly stationary. """ - v = torch.abs(torch.randn([m])) - i = torch.randint(0, m, []) - v[i] = 0.0 - return _generate_matrix_orthogonal_to_vector(v, n) + assert 1 < m + assert 0 < rank <= min(m - 1, n) + + u = torch.abs(torch.randn([m])) + split_index = torch.randint(1, m, []).item() + shuffled_range = torch.randperm(m) + v = torch.zeros(m) + v[shuffled_range[:split_index]] = normalize(u[shuffled_range[:split_index]], dim=0) + v_prime = torch.zeros(m) + v_prime[shuffled_range[split_index:]] = normalize(u[shuffled_range[split_index:]], dim=0) + U1 = torch.stack([v, v_prime]).T + U2 = _sample_semi_orthonormal_complement(U1) + U = torch.hstack([U1, U2]) + Vt = _sample_orthonormal_matrix(n) + S = torch.diag(torch.abs(torch.randn([rank]))) + A = U[:, 1 : rank + 1] @ S @ Vt[:rank, :] + return A -def _generate_matrix_orthogonal_to_vector(v: Tensor, n: int) -> Tensor: +def _sample_non_weak_matrix(m: int, n: int, rank: int) -> Tensor: """ - Generates a random matrix A of shape [len(v), n] with rank min(n, len(v) - 1) such that - v^T A = 0. + Samples a random non weakly-stationary matrix A of shape [m, n] with provided rank. + + This is done by sampling a positive u, and by then sampling a matrix A that has u as one of its + left-singular vectors, with positive singular value s. Any 0 <= v, v != 0, satisfies v^T A != 0. + Otherwise, we would have 0 = v^T A A^T u = s v^T u > 0, which is a contradiction. A is thus not + weakly stationary. 
""" - rank = min(n, len(v) - 1) - Q = normalize(v, dim=0).unsqueeze(1) - U = _generate_semi_orthonormal_complement(Q) - Vt = _generate_orthonormal_matrix(n) + assert 0 < rank <= min(m, n) + + u = torch.abs(torch.randn([m])) + U1 = normalize(u, dim=0).unsqueeze(1) + U2 = _sample_semi_orthonormal_complement(U1) + U = torch.hstack([U1, U2]) + Vt = _sample_orthonormal_matrix(n) S = torch.diag(torch.abs(torch.randn([rank]))) A = U[:, :rank] @ S @ Vt[:rank, :] return A -def _generate_orthonormal_matrix(dim: int) -> Tensor: - """Uniformly generates a random orthonormal matrix of shape [dim, dim].""" +def _sample_orthonormal_matrix(dim: int) -> Tensor: + """Uniformly samples a random orthonormal matrix of shape [dim, dim].""" - return _generate_semi_orthonormal_complement(torch.zeros([dim, 0])) + return _sample_semi_orthonormal_complement(torch.zeros([dim, 0])) -def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor: +def _sample_semi_orthonormal_complement(Q: Tensor) -> Tensor: """ - Uniformly generates a random semi-orthonormal matrix Q' (i.e. Q'^T Q' = I) of shape [m, m-k] + Uniformly samples a random semi-orthonormal matrix Q' (i.e. Q'^T Q' = I) of shape [m, m-k] orthogonal to Q, i.e. such that the concatenation [Q, Q'] is an orthonormal matrix. :param Q: A semi-orthonormal matrix (i.e. Q^T Q = I) of shape [m, k], with k <= m. 
@@ -77,7 +119,7 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor: return Q_prime -_matrix_dimension_triples = [ +_normal_dims = [ (1, 1, 1), (4, 3, 1), (4, 3, 2), @@ -86,32 +128,35 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor: (9, 11, 9), ] -_zero_matrices_shapes = [ - (1, 1), - (4, 3), - (9, 11), +_zero_dims = [ + (1, 1, 0), + (4, 3, 0), + (9, 11, 0), ] -_stationary_matrices_shapes = [ - (5, 3), - (9, 11), +_stationarity_dims = [ + (20, 10, 10), + (20, 10, 5), + (20, 10, 1), + (20, 100, 1), + (20, 100, 19), ] _scales = [0.0, 1e-10, 1e3, 1e5, 1e10, 1e15] -# Fix seed to fix randomness of matrix generation +# Fix seed to fix randomness of matrix sampling torch.manual_seed(0) -matrices = [_generate_matrix(m, n, rank) for m, n, rank in _matrix_dimension_triples] +matrices = [_sample_matrix(m, n, r) for m, n, r in _normal_dims] +zero_matrices = [torch.zeros([m, n]) for m, n, _ in _zero_dims] +strong_matrices = [_sample_strong_matrix(m, n, r) for m, n, r in _stationarity_dims] +strictly_weak_matrices = [_sample_strictly_weak_matrix(m, n, r) for m, n, r in _stationarity_dims] +non_weak_matrices = [_sample_non_weak_matrix(m, n, r) for m, n, r in _stationarity_dims] + scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices] -zero_matrices = [torch.zeros([m, n]) for m, n in _zero_matrices_shapes] -strong_stationary_matrices = [ - _generate_strong_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes -] -weak_stationary_matrices = [ - _generate_weak_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes -] -typical_matrices = zero_matrices + matrices + weak_stationary_matrices + strong_stationary_matrices + +non_strong_matrices = strictly_weak_matrices + non_weak_matrices +typical_matrices = zero_matrices + matrices + strong_matrices + non_strong_matrices scaled_matrices_2_plus_rows = [matrix for matrix in scaled_matrices if matrix.shape[0] >= 2] typical_matrices_2_plus_rows = [matrix for matrix in 
typical_matrices if matrix.shape[0] >= 2]
diff --git a/tests/unit/aggregation/_property_testers.py b/tests/unit/aggregation/_property_testers.py
index d4a52eadb..82cf29126 100644
--- a/tests/unit/aggregation/_property_testers.py
+++ b/tests/unit/aggregation/_property_testers.py
@@ -5,7 +5,7 @@
 from torchjd.aggregation import Aggregator
 
-from ._inputs import scaled_matrices, typical_matrices
+from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
 
 
 class ExpectedStructureProperty:
@@ -101,4 +101,25 @@ def _assert_linear_under_scaling_property(
     x = aggregator(torch.diag(alpha * c1 + beta * c2) @ matrix)
     expected = alpha * x1 + beta * x2
 
-    assert_close(x, expected, atol=8e-03, rtol=0)
+    assert_close(x, expected, atol=1e-02, rtol=0)
+
+
+class StrongStationarityProperty:
+    """
+    This class tests empirically that a given `Aggregator` is strongly stationary.
+
+    An aggregator `A` is strongly stationary if for any matrix `J` with `A(J)=0`, `J` is strongly
+    stationary, i.e., there exists `0 < v` such that `v^T J = 0`.
+    """
+
+    @classmethod
+    @mark.parametrize("matrix", non_strong_matrices)
+    def test_stationarity_property(cls, aggregator: Aggregator, matrix: Tensor):
+        cls._assert_stationarity_property(aggregator, matrix)
+
+    @classmethod
+    def _assert_stationarity_property(cls, aggregator: Aggregator, matrix: Tensor) -> None:
+        vector = aggregator(matrix)
+        norm = vector.norm().item()
+        assert norm > 1e-03
diff --git a/tests/unit/aggregation/test_cagrad.py b/tests/unit/aggregation/test_cagrad.py
index a0315d048..eee7c1797 100644
--- a/tests/unit/aggregation/test_cagrad.py
+++ b/tests/unit/aggregation/test_cagrad.py
@@ -33,7 +33,7 @@ def test_equivalence_mean(matrix: Tensor):
     result = ca_grad(matrix)
     expected = mean(matrix)
 
-    assert_close(result, expected)
+    assert_close(result, expected, atol=2e-1, rtol=0)
 
 
 @mark.parametrize(
diff --git a/tests/unit/aggregation/test_constant.py b/tests/unit/aggregation/test_constant.py
index e1dca61a8..85f3837da 100644
--- a/tests/unit/aggregation/test_constant.py
+++ b/tests/unit/aggregation/test_constant.py
@@ -7,8 +7,12 @@
 from torchjd.aggregation import Constant
 
-from ._inputs import scaled_matrices, typical_matrices
-from ._property_testers import ExpectedStructureProperty, LinearUnderScalingProperty
+from ._inputs import non_strong_matrices,
scaled_matrices, typical_matrices +from ._property_testers import ( + ExpectedStructureProperty, + LinearUnderScalingProperty, + StrongStationarityProperty, +) # The weights must be a vector of length equal to the number of rows in the matrix that it will be # applied to. Thus, each `Constant` instance is specific to matrices of a given number of rows. To @@ -28,8 +32,13 @@ def _make_aggregator(matrix: Tensor) -> Constant: _matrices_2 = typical_matrices _aggregators_2 = [_make_aggregator(matrix) for matrix in _matrices_2] +_matrices_3 = non_strong_matrices +_aggregators_3 = [_make_aggregator(matrix) for matrix in _matrices_3] + -class TestConstant(ExpectedStructureProperty, LinearUnderScalingProperty): +class TestConstant( + ExpectedStructureProperty, LinearUnderScalingProperty, StrongStationarityProperty +): # Override the parametrization of `test_expected_structure_property` to make the test use the # right aggregator with each matrix. @@ -43,6 +52,11 @@ def test_expected_structure_property(cls, aggregator: Constant, matrix: Tensor): def test_linear_under_scaling_property(cls, aggregator: Constant, matrix: Tensor): cls._assert_linear_under_scaling_property(aggregator, matrix) + @classmethod + @mark.parametrize(["aggregator", "matrix"], zip(_aggregators_3, _matrices_3)) + def test_stationarity_property(cls, aggregator: Constant, matrix: Tensor): + cls._assert_stationarity_property(aggregator, matrix) + @mark.parametrize( ["weights_shape", "expectation"], diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index 2143ca359..7a6ce1d77 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -7,12 +7,16 @@ ExpectedStructureProperty, NonConflictingProperty, PermutationInvarianceProperty, + StrongStationarityProperty, ) @mark.parametrize("aggregator", [DualProj()]) class TestDualProj( - ExpectedStructureProperty, NonConflictingProperty, PermutationInvarianceProperty + 
ExpectedStructureProperty, + NonConflictingProperty, + PermutationInvarianceProperty, + StrongStationarityProperty, ): pass diff --git a/tests/unit/aggregation/test_mean.py b/tests/unit/aggregation/test_mean.py index 22860bcfb..3165201a2 100644 --- a/tests/unit/aggregation/test_mean.py +++ b/tests/unit/aggregation/test_mean.py @@ -6,12 +6,16 @@ ExpectedStructureProperty, LinearUnderScalingProperty, PermutationInvarianceProperty, + StrongStationarityProperty, ) @mark.parametrize("aggregator", [Mean()]) class TestMean( - ExpectedStructureProperty, PermutationInvarianceProperty, LinearUnderScalingProperty + ExpectedStructureProperty, + PermutationInvarianceProperty, + LinearUnderScalingProperty, + StrongStationarityProperty, ): pass diff --git a/tests/unit/aggregation/test_random.py b/tests/unit/aggregation/test_random.py index b8064b29f..cb90ad73c 100644 --- a/tests/unit/aggregation/test_random.py +++ b/tests/unit/aggregation/test_random.py @@ -2,11 +2,11 @@ from torchjd.aggregation import Random -from ._property_testers import ExpectedStructureProperty +from ._property_testers import ExpectedStructureProperty, StrongStationarityProperty @mark.parametrize("aggregator", [Random()]) -class TestRandom(ExpectedStructureProperty): +class TestRandom(ExpectedStructureProperty, StrongStationarityProperty): pass diff --git a/tests/unit/aggregation/test_sum.py b/tests/unit/aggregation/test_sum.py index 6a36d2f0e..c3c72c9bf 100644 --- a/tests/unit/aggregation/test_sum.py +++ b/tests/unit/aggregation/test_sum.py @@ -6,11 +6,17 @@ ExpectedStructureProperty, LinearUnderScalingProperty, PermutationInvarianceProperty, + StrongStationarityProperty, ) @mark.parametrize("aggregator", [Sum()]) -class TestSum(ExpectedStructureProperty, PermutationInvarianceProperty, LinearUnderScalingProperty): +class TestSum( + ExpectedStructureProperty, + PermutationInvarianceProperty, + LinearUnderScalingProperty, + StrongStationarityProperty, +): pass diff --git 
a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 0b3cb7d25..f6e374616 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -8,6 +8,7 @@ LinearUnderScalingProperty, NonConflictingProperty, PermutationInvarianceProperty, + StrongStationarityProperty, ) @@ -17,6 +18,7 @@ class TestUPGrad( NonConflictingProperty, PermutationInvarianceProperty, LinearUnderScalingProperty, + StrongStationarityProperty, ): pass