Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
728413b
Make weak_stationary_matrices not contain strong_stationary_matrices
ValerianRey Mar 23, 2025
ea5a1f6
Add stationarity properties
ValerianRey Mar 23, 2025
e823026
Make TestMGDA extend WeakStationarityProperty
ValerianRey Mar 23, 2025
4f31b23
Make TestUPGrad extend StrongStationarityProperty
ValerianRey Mar 23, 2025
d588328
Fix formatting
ValerianRey Mar 23, 2025
8a34b2f
Simplify calls to parent class method
ValerianRey Mar 23, 2025
585da20
Add strong stationarity property to DualProj
PierreQuinton Mar 25, 2025
7b42859
Remove WeakStationarity property tester
PierreQuinton Mar 25, 2025
7ad85dd
Remove super class StationarityProperty
PierreQuinton Mar 25, 2025
100ef3f
Fix formatting
PierreQuinton Mar 25, 2025
fc83099
Merge branch 'main' into stationarity_property
PierreQuinton Mar 25, 2025
23a7de0
Merge branch 'main' into stationarity_property
PierreQuinton Mar 25, 2025
53c00eb
Improve docstring of StrongStationarityProperty
PierreQuinton Mar 25, 2025
aac99b4
Fix StrongStationarityProperty to be only one direction of the equiva…
ValerianRey Mar 25, 2025
1ec464a
Fix docstring
ValerianRey Mar 25, 2025
b96a9a1
Merge branch 'main' into stationarity_property
ValerianRey Mar 26, 2025
26fc65a
Improve stationary matrices generation
PierreQuinton Mar 26, 2025
6e84fec
Use correct stationary matrices in strong stationary property tester
PierreQuinton Mar 26, 2025
bd2476d
Add strong stationarity property to constant, mean, random and sum
PierreQuinton Mar 26, 2025
11c2c2c
fixes
PierreQuinton Mar 26, 2025
e2eb2c7
Merge branch 'main' into stationarity_property
ValerianRey Mar 27, 2025
b8a6dc9
Fix docstrings of generation functions
ValerianRey Mar 27, 2025
84ec3a4
Fix split_index in _generate_weak_non_strong_stationary_matrix
ValerianRey Mar 27, 2025
02ffaf6
Add comment to _generate_non_weak_stationary_matrix
ValerianRey Mar 27, 2025
5bd8af3
Rename v to z in _generate_weak_non_strong_stationary_matrix
ValerianRey Mar 27, 2025
e12cc9b
Add assertions on m and rank in _generate_weak_non_strong_stationary_…
ValerianRey Mar 27, 2025
797d391
Add comment in _generate_weak_non_strong_stationary_matrix
ValerianRey Mar 27, 2025
bf6b822
Add assertions on m and rank in _generate_strong_stationary_matrix
ValerianRey Mar 27, 2025
a80c33f
Add comment to _generate_strong_stationary_matrix
ValerianRey Mar 27, 2025
58794ab
Increase tolerance of LUS
ValerianRey Mar 27, 2025
e09b4c6
Increase tolerance of test_equivalence_mean
ValerianRey Mar 27, 2025
37a73f5
Fix rank assertion
ValerianRey Mar 27, 2025
e328378
Simplify inputs generation code
ValerianRey Mar 27, 2025
ee27d3d
Add failing SSProperty to AMTL CAGRAD and MGDA
ValerianRey Mar 27, 2025
30af731
Revert "Add failing SSProperty to AMTL CAGRAD and MGDA"
ValerianRey Mar 27, 2025
0a26c13
Improve docstring of StrongStationarityProperty
ValerianRey Mar 27, 2025
a5205f3
Simplify _generate_strong_matrix
ValerianRey Mar 27, 2025
f9cefc3
Clean up comment of _generate_strong_matrix
ValerianRey Mar 27, 2025
7fb5462
Rename z to u
ValerianRey Mar 27, 2025
72153a1
Clarify the implementation of _generate_strictly_weak_matrix
ValerianRey Mar 27, 2025
6918c68
Remove comment in _generate_strictly_weak_matrix
ValerianRey Mar 27, 2025
c936ac2
Rename v to u in _generate_non_weak_matrix
ValerianRey Mar 27, 2025
4334bc1
Remove comment in _generate_non_weak_matrix
ValerianRey Mar 27, 2025
2cb5823
Fix split_index in _generate_strictly_weak_matrix
ValerianRey Mar 27, 2025
e2ab11e
Add docstrings
ValerianRey Mar 27, 2025
cc6bdb3
Add assert in _generate_non_weak_matrix
ValerianRey Mar 27, 2025
ef75809
Merge dims into _stationarity_dims
ValerianRey Mar 27, 2025
e261e79
Adapt atol of test_equivalence_mean
ValerianRey Mar 27, 2025
847c1d6
Reduce dimensions in _stationarity_dims
ValerianRey Mar 27, 2025
3796e55
Rename _generate to _sample
ValerianRey Mar 27, 2025
b9c6bc6
Rename _generate to _sample v2
ValerianRey Mar 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 92 additions & 47 deletions tests/unit/aggregation/_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,65 +3,107 @@
from torch.nn.functional import normalize


def _generate_matrix(m: int, n: int, rank: int) -> Tensor:
"""Generates a random matrix A of shape [m, n] with provided rank."""
def _sample_matrix(m: int, n: int, rank: int) -> Tensor:
"""Samples a random matrix A of shape [m, n] with provided rank."""

U = _generate_orthonormal_matrix(m)
Vt = _generate_orthonormal_matrix(n)
U = _sample_orthonormal_matrix(m)
Vt = _sample_orthonormal_matrix(n)
S = torch.diag(torch.abs(torch.randn([rank])))
A = U[:, :rank] @ S @ Vt[:rank, :]
return A


def _generate_strong_stationary_matrix(m: int, n: int) -> Tensor:
def _sample_strong_matrix(m: int, n: int, rank: int) -> Tensor:
"""
Generates a random matrix A of shape [m, n] with rank min(n, m - 1), such that there exists a
vector 0<v with v^T A = 0.
Samples a random strongly stationary matrix A of shape [m, n] with provided rank.

Definition: A matrix A is said to be strongly stationary if there exists a vector 0 < v such
that v^T A = 0.

This is done by sampling a positive v, and by then sampling a matrix orthogonal to v.
"""

assert 1 < m
assert 0 < rank <= min(m - 1, n)

v = torch.abs(torch.randn([m]))
return _generate_matrix_orthogonal_to_vector(v, n)
U1 = normalize(v, dim=0).unsqueeze(1)
U2 = _sample_semi_orthonormal_complement(U1)
Vt = _sample_orthonormal_matrix(n)
S = torch.diag(torch.abs(torch.randn([rank])))
A = U2[:, :rank] @ S @ Vt[:rank, :]
return A


def _generate_weak_stationary_matrix(m: int, n: int) -> Tensor:
def _sample_strictly_weak_matrix(m: int, n: int, rank: int) -> Tensor:
"""
Generates a random matrix A of shape [m, n] with rank min(n, m - 1), such that there exists a
vector 0<=v with one coordinate equal to 0 and such that v^T A = 0.
Samples a random strictly weakly stationary matrix A of shape [m, n] with provided rank.

Definition: A matrix A is said to be weakly stationary if there exists a vector 0 <= v, v != 0,
such that v^T A = 0.

Note that if multiple coordinates of v were equal to 0, the generated matrix would still be weak
stationary, but here we only set one of them to 0 for simplicity.
Definition: A matrix A is said to be strictly weakly stationary if it is weakly stationary and
not strongly stationary, i.e. if there exists a vector 0 <= v, v != 0, such that v^T A = 0 and
there exists no vector 0 < w with w^T A = 0.

This is done by sampling two unit-norm vectors v, v', whose sum u is a positive vector. These
two vectors are also non-negative and non-zero, and are furthermore orthogonal. Then, a matrix
A, orthogonal to v, is sampled. By its orthogonality to v, A is weakly stationary. Moreover,
since v' is a non-negative left-singular vector of A with positive singular value s, any 0 < w
satisfies w^T A != 0. Otherwise, we would have 0 = w^T A A^T v' = s^2 w^T v' > 0, which is
a contradiction. A is thus also not strongly stationary.
"""

v = torch.abs(torch.randn([m]))
i = torch.randint(0, m, [])
v[i] = 0.0
return _generate_matrix_orthogonal_to_vector(v, n)
assert 1 < m
assert 0 < rank <= min(m - 1, n)

u = torch.abs(torch.randn([m]))
split_index = torch.randint(1, m, []).item()
shuffled_range = torch.randperm(m)
v = torch.zeros(m)
v[shuffled_range[:split_index]] = normalize(u[shuffled_range[:split_index]], dim=0)
v_prime = torch.zeros(m)
v_prime[shuffled_range[split_index:]] = normalize(u[shuffled_range[split_index:]], dim=0)
U1 = torch.stack([v, v_prime]).T
U2 = _sample_semi_orthonormal_complement(U1)
U = torch.hstack([U1, U2])
Vt = _sample_orthonormal_matrix(n)
S = torch.diag(torch.abs(torch.randn([rank])))
A = U[:, 1 : rank + 1] @ S @ Vt[:rank, :]
return A


def _generate_matrix_orthogonal_to_vector(v: Tensor, n: int) -> Tensor:
def _sample_non_weak_matrix(m: int, n: int, rank: int) -> Tensor:
"""
Generates a random matrix A of shape [len(v), n] with rank min(n, len(v) - 1) such that
v^T A = 0.
Samples a random non weakly-stationary matrix A of shape [m, n] with provided rank.

This is done by sampling a positive u, and by then sampling a matrix A that has u as one of its
left-singular vectors, with positive singular value s. Any 0 <= v, v != 0, satisfies v^T A != 0.
Otherwise, we would have 0 = v^T A A^T u = s^2 v^T u > 0, which is a contradiction. A is thus
not weakly stationary.
"""

rank = min(n, len(v) - 1)
Q = normalize(v, dim=0).unsqueeze(1)
U = _generate_semi_orthonormal_complement(Q)
Vt = _generate_orthonormal_matrix(n)
assert 0 < rank <= min(m, n)

u = torch.abs(torch.randn([m]))
U1 = normalize(u, dim=0).unsqueeze(1)
U2 = _sample_semi_orthonormal_complement(U1)
U = torch.hstack([U1, U2])
Vt = _sample_orthonormal_matrix(n)
S = torch.diag(torch.abs(torch.randn([rank])))
A = U[:, :rank] @ S @ Vt[:rank, :]
return A


def _generate_orthonormal_matrix(dim: int) -> Tensor:
"""Uniformly generates a random orthonormal matrix of shape [dim, dim]."""
def _sample_orthonormal_matrix(dim: int) -> Tensor:
"""Uniformly samples a random orthonormal matrix of shape [dim, dim]."""

return _generate_semi_orthonormal_complement(torch.zeros([dim, 0]))
return _sample_semi_orthonormal_complement(torch.zeros([dim, 0]))


def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
def _sample_semi_orthonormal_complement(Q: Tensor) -> Tensor:
"""
Uniformly generates a random semi-orthonormal matrix Q' (i.e. Q'^T Q' = I) of shape [m, m-k]
Uniformly samples a random semi-orthonormal matrix Q' (i.e. Q'^T Q' = I) of shape [m, m-k]
orthogonal to Q, i.e. such that the concatenation [Q, Q'] is an orthonormal matrix.

:param Q: A semi-orthonormal matrix (i.e. Q^T Q = I) of shape [m, k], with k <= m.
Expand All @@ -77,7 +119,7 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
return Q_prime


_matrix_dimension_triples = [
_normal_dims = [
(1, 1, 1),
(4, 3, 1),
(4, 3, 2),
Expand All @@ -86,32 +128,35 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
(9, 11, 9),
]

_zero_matrices_shapes = [
(1, 1),
(4, 3),
(9, 11),
_zero_dims = [
(1, 1, 0),
(4, 3, 0),
(9, 11, 0),
]

_stationary_matrices_shapes = [
(5, 3),
(9, 11),
_stationarity_dims = [
(20, 10, 10),
(20, 10, 5),
(20, 10, 1),
(20, 100, 1),
(20, 100, 19),
]

_scales = [0.0, 1e-10, 1e3, 1e5, 1e10, 1e15]

# Fix seed to fix randomness of matrix generation
# Fix seed to fix randomness of matrix sampling
torch.manual_seed(0)

matrices = [_generate_matrix(m, n, rank) for m, n, rank in _matrix_dimension_triples]
matrices = [_sample_matrix(m, n, r) for m, n, r in _normal_dims]
zero_matrices = [torch.zeros([m, n]) for m, n, _ in _zero_dims]
strong_matrices = [_sample_strong_matrix(m, n, r) for m, n, r in _stationarity_dims]
strictly_weak_matrices = [_sample_strictly_weak_matrix(m, n, r) for m, n, r in _stationarity_dims]
non_weak_matrices = [_sample_non_weak_matrix(m, n, r) for m, n, r in _stationarity_dims]

scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices]
zero_matrices = [torch.zeros([m, n]) for m, n in _zero_matrices_shapes]
strong_stationary_matrices = [
_generate_strong_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes
]
weak_stationary_matrices = [
_generate_weak_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes
]
typical_matrices = zero_matrices + matrices + weak_stationary_matrices + strong_stationary_matrices

non_strong_matrices = strictly_weak_matrices + non_weak_matrices
typical_matrices = zero_matrices + matrices + strong_matrices + non_strong_matrices

scaled_matrices_2_plus_rows = [matrix for matrix in scaled_matrices if matrix.shape[0] >= 2]
typical_matrices_2_plus_rows = [matrix for matrix in typical_matrices if matrix.shape[0] >= 2]
25 changes: 23 additions & 2 deletions tests/unit/aggregation/_property_testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from torchjd.aggregation import Aggregator

from ._inputs import scaled_matrices, typical_matrices
from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices


class ExpectedStructureProperty:
Expand Down Expand Up @@ -101,4 +101,25 @@ def _assert_linear_under_scaling_property(
x = aggregator(torch.diag(alpha * c1 + beta * c2) @ matrix)
expected = alpha * x1 + beta * x2

assert_close(x, expected, atol=8e-03, rtol=0)
assert_close(x, expected, atol=1e-02, rtol=0)


class StrongStationarityProperty:
    """
    Property tester checking empirically that an `Aggregator` is strongly stationary.

    An aggregator `A` is called strongly stationary when `A(J)=0` implies that `J` is strongly
    stationary, i.e. that some `0<w` satisfies `J^T w=0`. The check performed here is the
    contrapositive: for every matrix `J` that is not strongly stationary, `A(J)` must be non-zero.
    """

    @classmethod
    @mark.parametrize("matrix", non_strong_matrices)
    def test_stationarity_property(cls, aggregator: Aggregator, matrix: Tensor):
        # Delegates to the static helper so that subclasses overriding the parametrization can
        # reuse the same assertion.
        cls._assert_stationarity_property(aggregator, matrix)

    @staticmethod
    def _assert_stationarity_property(aggregator: Aggregator, matrix: Tensor) -> None:
        # The aggregation is considered non-zero when its norm exceeds a fixed small threshold.
        aggregation_norm = aggregator(matrix).norm().item()
        assert aggregation_norm > 1e-03
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_equivalence_mean(matrix: Tensor):
result = ca_grad(matrix)
expected = mean(matrix)

assert_close(result, expected)
assert_close(result, expected, atol=2e-1, rtol=0)


@mark.parametrize(
Expand Down
20 changes: 17 additions & 3 deletions tests/unit/aggregation/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@

from torchjd.aggregation import Constant

from ._inputs import scaled_matrices, typical_matrices
from ._property_testers import ExpectedStructureProperty, LinearUnderScalingProperty
from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
from ._property_testers import (
ExpectedStructureProperty,
LinearUnderScalingProperty,
StrongStationarityProperty,
)

# The weights must be a vector of length equal to the number of rows in the matrix that it will be
# applied to. Thus, each `Constant` instance is specific to matrices of a given number of rows. To
Expand All @@ -28,8 +32,13 @@ def _make_aggregator(matrix: Tensor) -> Constant:
_matrices_2 = typical_matrices
_aggregators_2 = [_make_aggregator(matrix) for matrix in _matrices_2]

_matrices_3 = non_strong_matrices
_aggregators_3 = [_make_aggregator(matrix) for matrix in _matrices_3]


class TestConstant(ExpectedStructureProperty, LinearUnderScalingProperty):
class TestConstant(
ExpectedStructureProperty, LinearUnderScalingProperty, StrongStationarityProperty
):
# Override the parametrization of `test_expected_structure_property` to make the test use the
# right aggregator with each matrix.

Expand All @@ -43,6 +52,11 @@ def test_expected_structure_property(cls, aggregator: Constant, matrix: Tensor):
def test_linear_under_scaling_property(cls, aggregator: Constant, matrix: Tensor):
cls._assert_linear_under_scaling_property(aggregator, matrix)

@classmethod
@mark.parametrize(["aggregator", "matrix"], zip(_aggregators_3, _matrices_3))
def test_stationarity_property(cls, aggregator: Constant, matrix: Tensor):
cls._assert_stationarity_property(aggregator, matrix)


@mark.parametrize(
["weights_shape", "expectation"],
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/aggregation/test_dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
ExpectedStructureProperty,
NonConflictingProperty,
PermutationInvarianceProperty,
StrongStationarityProperty,
)


@mark.parametrize("aggregator", [DualProj()])
class TestDualProj(
ExpectedStructureProperty, NonConflictingProperty, PermutationInvarianceProperty
ExpectedStructureProperty,
NonConflictingProperty,
PermutationInvarianceProperty,
StrongStationarityProperty,
):
pass

Expand Down
6 changes: 5 additions & 1 deletion tests/unit/aggregation/test_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
ExpectedStructureProperty,
LinearUnderScalingProperty,
PermutationInvarianceProperty,
StrongStationarityProperty,
)


@mark.parametrize("aggregator", [Mean()])
class TestMean(
ExpectedStructureProperty, PermutationInvarianceProperty, LinearUnderScalingProperty
ExpectedStructureProperty,
PermutationInvarianceProperty,
LinearUnderScalingProperty,
StrongStationarityProperty,
):
pass

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/aggregation/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from torchjd.aggregation import Random

from ._property_testers import ExpectedStructureProperty
from ._property_testers import ExpectedStructureProperty, StrongStationarityProperty


@mark.parametrize("aggregator", [Random()])
class TestRandom(ExpectedStructureProperty):
class TestRandom(ExpectedStructureProperty, StrongStationarityProperty):
pass


Expand Down
8 changes: 7 additions & 1 deletion tests/unit/aggregation/test_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@
ExpectedStructureProperty,
LinearUnderScalingProperty,
PermutationInvarianceProperty,
StrongStationarityProperty,
)


@mark.parametrize("aggregator", [Sum()])
class TestSum(ExpectedStructureProperty, PermutationInvarianceProperty, LinearUnderScalingProperty):
class TestSum(
ExpectedStructureProperty,
PermutationInvarianceProperty,
LinearUnderScalingProperty,
StrongStationarityProperty,
):
pass


Expand Down
2 changes: 2 additions & 0 deletions tests/unit/aggregation/test_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
LinearUnderScalingProperty,
NonConflictingProperty,
PermutationInvarianceProperty,
StrongStationarityProperty,
)


Expand All @@ -17,6 +18,7 @@ class TestUPGrad(
NonConflictingProperty,
PermutationInvarianceProperty,
LinearUnderScalingProperty,
StrongStationarityProperty,
):
pass

Expand Down