Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
8fce23f
Remove tests of shape and rank of matrices
PierreQuinton Feb 23, 2025
c6b4798
Reimplements matrices generation, adds a distinction between strong a…
PierreQuinton Feb 23, 2025
d3bd73b
Improve naming of matrices
PierreQuinton Feb 23, 2025
7b4cfb7
Add testing of weak stationary matrices for the non-conflicting property
PierreQuinton Feb 23, 2025
4142861
make utility functions protected
PierreQuinton Feb 23, 2025
088c32f
Stationary matrices are now of rank exactly of rank=min(n_cols, n_row…
PierreQuinton Feb 24, 2025
ec6caa7
Remove [1,1] matrix from set of stationary matrices.
PierreQuinton Feb 24, 2025
9bbe4de
Replace torch.qr by torch.linalg.qr
ValerianRey Mar 18, 2025
a618862
Reorder functions
ValerianRey Mar 18, 2025
6054aaf
Fix docstring of _generate_orthogonal_matrix
ValerianRey Mar 18, 2025
e91a313
Fix docstring formatting
ValerianRey Mar 18, 2025
181f0bd
Fix docstrings of _generate_strong_stationary_matrix and _generate_we…
ValerianRey Mar 18, 2025
ba95e7a
Fix docstring formatting
ValerianRey Mar 18, 2025
974a45c
Fix docstring formatting
ValerianRey Mar 18, 2025
da8ef38
Fix docstring of _generate_matrix_with_orthogonal_vector
ValerianRey Mar 18, 2025
1c0546e
Simplify _complete_orthogonal_matrix
ValerianRey Mar 18, 2025
060a265
Fix names and docstrings
ValerianRey Mar 20, 2025
fe56822
Rename m to n_rows
ValerianRey Mar 20, 2025
82207da
Rename generated matrices
ValerianRey Mar 20, 2025
9475867
Rename m to n
ValerianRey Mar 20, 2025
b7ae077
Factorize generation of orthonormal matrices
ValerianRey Mar 20, 2025
8e995fe
Improve docstring formatting
ValerianRey Mar 20, 2025
200e262
Improve _generate_weak_stationary_matrix
ValerianRey Mar 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 63 additions & 117 deletions tests/unit/aggregation/_inputs.py
Original file line number Diff line number Diff line change
@@ -1,139 +1,80 @@
import torch
from torch import Tensor
from torch.nn.functional import normalize


def _check_valid_dimensions(n_rows: int, n_cols: int) -> None:
if n_rows < 1:
raise ValueError(
f"Parameter `n_rows` should be a positive integer. Found n_rows = {n_rows}."
)
if n_cols < 1:
raise ValueError(
f"Parameter `n_cols` should be a positive integer. Found n_cols = {n_cols}."
)


def _check_valid_rank(n_rows: int, n_cols: int, rank: int) -> None:
if rank < 0:
raise ValueError(f"Parameter `rank` should be a non-negative integer. Found rank = {rank}.")
if rank > n_rows:
raise ValueError(
"Parameter `rank` should not be larger than the number of rows. "
f"Found rank = {rank} and n_rows = {n_rows}."
)
if rank > n_cols:
raise ValueError(
"Parameter `rank` should not be larger than the number of columns. "
f"Found rank = {rank} and n_cols = {n_cols}."
)


def _augment_orthogonal_matrix(orthogonal_matrix: Tensor) -> Tensor:
"""
Augments the provided matrix with one more column that is filled with a random unit vector that
is orthogonal to the provided orthogonal_matrix.
"""
def _generate_matrix(m: int, n: int, rank: int) -> Tensor:
"""Generates a random matrix A of shape [m, n] with provided rank."""

n_rows = orthogonal_matrix.shape[0]
projection = orthogonal_matrix @ orthogonal_matrix.T
zero = torch.zeros([n_rows])
while True:
random_vector = torch.randn([n_rows])
projected_vector = random_vector - projection @ random_vector
if not torch.allclose(projected_vector, zero):
break
projected_vector = torch.nn.functional.normalize(projected_vector, dim=0).reshape([-1, 1])
augmented_matrix = torch.cat((orthogonal_matrix, projected_vector), dim=1)
return augmented_matrix
U = _generate_orthonormal_matrix(m)
Vt = _generate_orthonormal_matrix(n)
S = torch.diag(torch.abs(torch.randn([rank])))
A = U[:, :rank] @ S @ Vt[:rank, :]
return A


def _complete_orthogonal_matrix(orthogonal_matrix: Tensor, n_cols: int) -> Tensor:
def _generate_strong_stationary_matrix(m: int, n: int) -> Tensor:
"""
Iteratively augments the input ``orthogonal_matrix`` with columns that are orthogonal to its
existing columns, until it has the required number of columns. Returns the obtained
orthogonal matrix.
Generates a random matrix A of shape [m, n] with rank min(n, m - 1), such that there exists a
vector 0<v with v^T A = 0.
"""

if orthogonal_matrix.shape[1] > n_cols:
raise ValueError(
f"Parameter `n_cols` should exceed the second dimension of the provided matrix. Found "
f"`n_cols = {n_cols}` and `partial_matrix.shape[1] = {orthogonal_matrix.shape[1]}`."
)

for i in range(n_cols - 1):
orthogonal_matrix = _augment_orthogonal_matrix(orthogonal_matrix)
return orthogonal_matrix


def _generate_unitary_matrix(n_rows: int, n_cols: int) -> Tensor:
"""Generates a unitary matrix of shape [n_rows, n_cols]."""
v = torch.abs(torch.randn([m]))
return _generate_matrix_orthogonal_to_vector(v, n)

_check_valid_dimensions(n_rows, n_cols)
partial_matrix = torch.randn([n_rows, 1])
partial_matrix = torch.nn.functional.normalize(partial_matrix, dim=0)

unitary_matrix = _complete_orthogonal_matrix(partial_matrix, n_cols)
return unitary_matrix


def _generate_unitary_matrix_with_positive_column(n_rows: int, n_cols: int) -> Tensor:
def _generate_weak_stationary_matrix(m: int, n: int) -> Tensor:
"""
Generates a unitary matrix of shape [n_rows, n_cols] with the first column consisting of an all
positive vector.
Generates a random matrix A of shape [m, n] with rank min(n, m - 1), such that there exists a
vector 0<=v with one coordinate equal to 0 and such that v^T A = 0.

Note that if multiple coordinates of v were equal to 0, the generated matrix would still be weak
stationary, but here we only set one of them to 0 for simplicity.
"""
_check_valid_dimensions(n_rows, n_cols)
partial_matrix = torch.abs(torch.randn([n_rows, 1]))
partial_matrix = torch.nn.functional.normalize(partial_matrix, dim=0)

unitary_matrix_with_positive_column = _complete_orthogonal_matrix(partial_matrix, n_cols)
return unitary_matrix_with_positive_column
v = torch.abs(torch.randn([m]))
i = torch.randint(0, m, [])
v[i] = 0.0
return _generate_matrix_orthogonal_to_vector(v, n)


def _generate_diagonal_singular_values(rank: int) -> Tensor:
def _generate_matrix_orthogonal_to_vector(v: Tensor, n: int) -> Tensor:
"""
generates a diagonal matrix of positive values sorted in descending order.
Generates a random matrix A of shape [len(v), n] with rank min(n, len(v) - 1) such that
v^T A = 0.
"""
singular_values = torch.abs(torch.randn([rank]))
singular_values = torch.sort(singular_values, descending=True)[0]
S = torch.diag(singular_values)
return S


def generate_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
"""
Generates a random matrix of shape [``n_rows``, ``n_cols``] with provided ``rank``.
"""
rank = min(n, len(v) - 1)
Q = normalize(v, dim=0).unsqueeze(1)
U = _generate_semi_orthonormal_complement(Q)
Vt = _generate_orthonormal_matrix(n)
S = torch.diag(torch.abs(torch.randn([rank])))
A = U[:, :rank] @ S @ Vt[:rank, :]
return A

_check_valid_rank(n_rows, n_cols, rank)

if rank == 0:
matrix = torch.zeros([n_rows, n_cols])
else:
U = _generate_unitary_matrix(n_rows, rank)
V = _generate_unitary_matrix(n_cols, rank)
S = _generate_diagonal_singular_values(rank)
matrix = U @ S @ V.T
def _generate_orthonormal_matrix(dim: int) -> Tensor:
"""Uniformly generates a random orthonormal matrix of shape [dim, dim]."""

return matrix
return _generate_semi_orthonormal_complement(torch.zeros([dim, 0]))


def generate_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
"""
Generates a random matrix of shape [``n_rows``, ``n_cols``] with provided ``rank``. The matrix
has a singular triple (u, s, v) such that u is all (strictly) positive and s is 0.
Uniformly generates a random semi-orthonormal matrix Q' (i.e. Q'^T Q' = I) of shape [m, m-k]
orthogonal to Q, i.e. such that the concatenation [Q, Q'] is an orthonormal matrix.

:param Q: A semi-orthonormal matrix (i.e. Q^T Q = I) of shape [m, k], with k <= m.
"""

_check_valid_rank(n_rows, n_cols, rank)
if rank == 0:
matrix = torch.zeros([n_rows, n_cols])
else:
U = _generate_unitary_matrix_with_positive_column(n_rows, rank)
V = _generate_unitary_matrix(n_cols, rank)
S = _generate_diagonal_singular_values(rank)
S[0, 0] = 0.0
matrix = U @ S @ V.T
m, k = Q.shape
A = torch.randn([m, m - k])

return matrix
# project A onto the orthogonal complement of Q
A_proj = A - Q @ (Q.T @ A)

Q_prime, _ = torch.linalg.qr(A_proj)
return Q_prime


_matrix_dimension_triples = [
Expand All @@ -145,27 +86,32 @@ def generate_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
(9, 11, 9),
]

_zero_rank_matrix_shapes = [
_zero_matrices_shapes = [
(1, 1),
(4, 3),
(9, 11),
]

_stationary_matrices_shapes = [
(5, 3),
(9, 11),
]

_scales = [0.0, 1e-10, 1.0, 1e3, 1e5, 1e10, 1e15]

# Fix seed to fix randomness of matrix generation
torch.manual_seed(0)

matrices = [
generate_matrix(n_rows, n_cols, rank) for n_rows, n_cols, rank in _matrix_dimension_triples
]
matrices = [_generate_matrix(m, n, rank) for m, n, rank in _matrix_dimension_triples]
scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices]
zero_rank_matrices = [torch.zeros([n_rows, n_cols]) for n_rows, n_cols in _zero_rank_matrix_shapes]
matrices_2_plus_rows = [matrix for matrix in matrices + zero_rank_matrices if matrix.shape[0] >= 2]
zero_matrices = [torch.zeros([m, n]) for m, n in _zero_matrices_shapes]
matrices_2_plus_rows = [matrix for matrix in matrices + zero_matrices if matrix.shape[0] >= 2]
scaled_matrices_2_plus_rows = [
matrix for matrix in scaled_matrices + zero_rank_matrices if matrix.shape[0] >= 2
matrix for matrix in scaled_matrices + zero_matrices if matrix.shape[0] >= 2
]
strong_stationary_matrices = [
_generate_strong_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes
]
stationary_matrices = [
generate_stationary_matrix(n_rows, n_cols, rank)
for n_rows, n_cols, rank in _matrix_dimension_triples
weak_stationary_matrices = strong_stationary_matrices + [
_generate_weak_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes
]
6 changes: 3 additions & 3 deletions tests/unit/aggregation/_property_testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from torchjd.aggregation import Aggregator

from ._inputs import matrices, scaled_matrices, stationary_matrices, zero_rank_matrices
from ._inputs import matrices, scaled_matrices, weak_stationary_matrices, zero_matrices


class ExpectedStructureProperty:
Expand All @@ -17,7 +17,7 @@ class ExpectedStructureProperty:
"""

@classmethod
@mark.parametrize("matrix", scaled_matrices + zero_rank_matrices)
@mark.parametrize("matrix", scaled_matrices + zero_matrices)
def test_expected_structure_property(cls, aggregator: Aggregator, matrix: Tensor):
cls._assert_expected_structure_property(aggregator, matrix)

Expand All @@ -34,7 +34,7 @@ class NonConflictingProperty:
"""

@classmethod
@mark.parametrize("matrix", stationary_matrices + matrices)
@mark.parametrize("matrix", weak_stationary_matrices + matrices)
def test_non_conflicting_property(
cls,
aggregator: Aggregator,
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/aggregation/test_cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from torchjd.aggregation import CAGrad, Mean

from ._inputs import matrices, stationary_matrices
from ._inputs import matrices, strong_stationary_matrices
from ._property_testers import ExpectedStructureProperty, NonConflictingProperty


Expand All @@ -20,7 +20,7 @@ class TestCAGradNonConflicting(NonConflictingProperty):
pass


@mark.parametrize("matrix", stationary_matrices + matrices)
@mark.parametrize("matrix", strong_stationary_matrices + matrices)
def test_equivalence_mean(matrix: Tensor):
"""Tests that CAGrad is equivalent to Mean when c=0."""

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/aggregation/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from torchjd.aggregation import Constant

from ._inputs import matrices, scaled_matrices, stationary_matrices, zero_rank_matrices
from ._inputs import matrices, scaled_matrices, strong_stationary_matrices, zero_matrices
from ._property_testers import ExpectedStructureProperty

# The weights must be a vector of length equal to the number of rows in the matrix that it will be
Expand All @@ -19,10 +19,10 @@ def _make_aggregator(matrix: Tensor) -> Constant:
return Constant(weights)


_matrices_1 = scaled_matrices + zero_rank_matrices
_matrices_1 = scaled_matrices + zero_matrices
_aggregators_1 = [_make_aggregator(matrix) for matrix in _matrices_1]

_matrices_2 = matrices + stationary_matrices
_matrices_2 = matrices + strong_stationary_matrices
_aggregators_2 = [_make_aggregator(matrix) for matrix in _matrices_2]


Expand Down