Skip to content

Commit 4fd394d

Browse files
authored
test(aggregation): Refactor property testing (#339)
* Transform property testers in _property_testers.py into assertion functions in _asserts.py * Add n_runs parameter to assert_linear_under_scaling * Update aggregator tests to parameterize and use the new assertion functions rather than property testers * Fix dtype issues * Add test_permutation_invariant to ConFIG and IMTLG * Update tolerances * Refactor matrix sampling functions into classes and move them to _matrix_samplers.py * Add dtype parameter to the constructor of MatrixSamplers * Add rng parameter to the call method of MatrixSamplers Note that we decided to have default values for some parameters of the assertion functions (atol, rtol, n_runs and threshold), so that we still have the benefit of having most aggregators tested with the same tolerances, for ease of maintenance (as we don't want to update each aggregator's tolerance individually each time inputs, hardware or cuda drivers change). We have decided that only UPGrad and DualProj would use custom (and tight) tolerances, so that we're aware if we make a regression for them. For some reason, the permutation_invariance tests now pass on cuda for ConFIG and IMTLG, so they now test it too. This could be due to a change in the cuda implementation of pinv.
1 parent 1443dcd commit 4fd394d

20 files changed

+627
-453
lines changed

tests/unit/aggregation/_asserts.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import torch
2+
from pytest import raises
3+
from torch import Tensor
4+
from torch.testing import assert_close
5+
6+
from torchjd.aggregation import Aggregator
7+
from torchjd.aggregation._utils.non_differentiable import NonDifferentiableError
8+
9+
10+
def assert_expected_structure(aggregator: Aggregator, matrix: Tensor) -> None:
    """
    Tests that the vector returned by the `__call__` method of an `Aggregator` has the expected
    structure: its shape must match the trailing dimensions of the input matrix (i.e. the number
    of columns), and it must only contain finite values (no `nan`, `inf` or `-inf`). Note that
    this property implies that the `__call__` method does not raise any exception.
    """

    # Any exception raised by the call is itself a failure of this assertion.
    result = aggregator(matrix)

    expected_shape = matrix.shape[1:]
    assert result.shape == expected_shape
    assert bool(torch.isfinite(result).all())
21+
22+
23+
def assert_non_conflicting(
    aggregator: Aggregator, matrix: Tensor, atol: float = 4e-04, rtol: float = 4e-04
) -> None:
    """Tests empirically that a given `Aggregator` satisfies the non-conflicting property."""

    aggregation = aggregator(matrix)
    # Inner product between each row of the matrix and the aggregated direction.
    inner_products = matrix @ aggregation
    non_negative_part = inner_products[inner_products >= 0]
    # If no inner product is negative, discarding the negative ones leaves the norm unchanged.
    assert_close(non_negative_part.norm(), inner_products.norm(), atol=atol, rtol=rtol)
32+
33+
34+
def assert_permutation_invariant(
    aggregator: Aggregator,
    matrix: Tensor,
    n_runs: int = 5,
    atol: float = 1e-04,
    rtol: float = 1e-04,
) -> None:
    """
    Tests empirically that for a given `Aggregator`, randomly permuting rows of the input matrix
    doesn't change the aggregation.
    """

    reference = aggregator(matrix)

    for _ in range(n_runs):
        # Shuffle the rows and check that the aggregation is unaffected.
        permutation = torch.randperm(matrix.size(dim=0))
        shuffled = matrix[permutation]
        assert_close(reference, aggregator(shuffled), atol=atol, rtol=rtol)
57+
58+
59+
def assert_linear_under_scaling(
    aggregator: Aggregator,
    matrix: Tensor,
    n_runs: int = 5,
    atol: float = 1e-04,
    rtol: float = 1e-04,
) -> None:
    """
    Tests empirically that a given `Aggregator` satisfies the linear under scaling property:
    for random non-negative row scalings c1, c2 and random scalars alpha, beta, we must have
    A(diag(alpha*c1 + beta*c2) @ J) == alpha * A(diag(c1) @ J) + beta * A(diag(c2) @ J).

    :param aggregator: The aggregator under test.
    :param matrix: The input matrix J.
    :param n_runs: Number of random scalings to try.
    :param atol: Absolute tolerance of the comparison.
    :param rtol: Relative tolerance of the comparison.
    """

    for _ in range(n_runs):
        # Sample the coefficients with the same dtype AND device as the input matrix. Previously
        # they were always created on CPU, which made `torch.diag(c1) @ matrix` fail whenever the
        # matrix lives on an accelerator (e.g. CUDA).
        c1 = torch.rand(matrix.shape[0], dtype=matrix.dtype, device=matrix.device)
        c2 = torch.rand(matrix.shape[0], dtype=matrix.dtype, device=matrix.device)
        alpha = torch.rand([], dtype=matrix.dtype, device=matrix.device)
        beta = torch.rand([], dtype=matrix.dtype, device=matrix.device)

        # Aggregations of the two independently scaled matrices.
        x1 = aggregator(torch.diag(c1) @ matrix)
        x2 = aggregator(torch.diag(c2) @ matrix)
        # Aggregation of the matrix scaled by the linear combination of the scalings.
        x = aggregator(torch.diag(alpha * c1 + beta * c2) @ matrix)
        expected = alpha * x1 + beta * x2

        assert_close(x, expected, atol=atol, rtol=rtol)
80+
81+
82+
def assert_strongly_stationary(
    aggregator: Aggregator, matrix: Tensor, threshold: float = 5e-03
) -> None:
    """
    Tests empirically that a given `Aggregator` is strongly stationary.

    An aggregator `A` is strongly stationary if for any matrix `J` with `A(J)=0`, `J` is strongly
    stationary, i.e., there exists `0<w` such that `J^T w=0`. Here we test the contraposition:
    whenever `J` is not strongly stationary, we must have `A(J) != 0`.
    """

    # The matrix is assumed not strongly stationary, so the aggregation must be far from zero.
    aggregation_norm = aggregator(matrix).norm().item()
    assert aggregation_norm > threshold
96+
97+
98+
def assert_non_differentiable(aggregator: Aggregator, matrix: Tensor):
    """
    Tests empirically that a given non-differentiable `Aggregator` correctly raises a
    NonDifferentiableError whenever we try to backward through it.
    """

    output = aggregator(matrix)
    # Seed the backward pass with a gradient of ones; the backward itself must raise.
    grad_seed = torch.ones_like(output)
    with raises(NonDifferentiableError):
        output.backward(grad_seed)

tests/unit/aggregation/_inputs.py

Lines changed: 8 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -1,123 +1,7 @@
11
import torch
2-
from torch import Tensor
3-
from torch.nn.functional import normalize
4-
5-
6-
def _sample_matrix(m: int, n: int, rank: int) -> Tensor:
    """Samples a random matrix A of shape [m, n] with provided rank."""

    # Build A from a truncated SVD A = U_r S V_r^T, with strictly positive singular values
    # (almost surely), so that A has exactly the requested rank.
    left_factor = _sample_orthonormal_matrix(m)
    right_factor = _sample_orthonormal_matrix(n)
    spectrum = torch.diag(torch.abs(torch.randn([rank])))
    return left_factor[:, :rank] @ spectrum @ right_factor[:rank, :]
14-
15-
16-
def _sample_strong_matrix(m: int, n: int, rank: int) -> Tensor:
    """
    Samples a random strongly stationary matrix A of shape [m, n] with provided rank.

    Definition: A matrix A is said to be strongly stationary if there exists a vector 0 < v such
    that v^T A = 0.

    This is done by sampling a positive v, and by then sampling a matrix orthogonal to v.
    """

    assert m > 1
    assert 0 < rank <= min(m - 1, n)

    # A (almost surely) positive vector that the column space of A must be orthogonal to.
    positive_vector = torch.abs(torch.randn([m]))
    v_column = normalize(positive_vector, dim=0).unsqueeze(1)
    # Every column of the complement is orthogonal to v, hence v^T A = 0 by construction.
    complement = _sample_semi_orthonormal_complement(v_column)
    right_factor = _sample_orthonormal_matrix(n)
    spectrum = torch.diag(torch.abs(torch.randn([rank])))
    return complement[:, :rank] @ spectrum @ right_factor[:rank, :]
36-
37-
38-
def _sample_strictly_weak_matrix(m: int, n: int, rank: int) -> Tensor:
    """
    Samples a random strictly weakly stationary matrix A of shape [m, n] with provided rank.

    Definition: A matrix A is said to be weakly stationary if there exists a vector 0 <= v,
    v != 0, such that v^T A = 0.

    Definition: A matrix A is said to be strictly weakly stationary if it is weakly stationary
    and not strongly stationary, i.e. if there exists a vector 0 <= v, v != 0, such that
    v^T A = 0 and there exists no vector 0 < w with w^T A = 0.

    This is done by sampling two unit-norm vectors v, v', whose sum u is a positive vector.
    These two vectors are also non-negative and non-zero, and are furthermore orthogonal (their
    supports are disjoint). Then, a matrix A, orthogonal to v, is sampled. By its orthogonality
    to v, A is weakly stationary. Moreover, since v' is a non-negative left-singular vector of A
    with positive singular value s, any 0 < w satisfies w^T A != 0. Otherwise, we would have
    0 = w^T A A^T v' = s w^T v' > 0, which is a contradiction. A is thus also not strongly
    stationary.
    """

    assert m > 1
    assert 0 < rank <= min(m - 1, n)

    positive_vector = torch.abs(torch.randn([m]))
    cut = torch.randint(1, m, []).item()
    permuted_indices = torch.randperm(m)
    first_support = permuted_indices[:cut]
    second_support = permuted_indices[cut:]

    # v and v' have disjoint supports partitioning the positive vector, so they are orthogonal,
    # non-negative and non-zero, and v + v' (up to normalization) is positive.
    v = torch.zeros(m)
    v[first_support] = normalize(positive_vector[first_support], dim=0)
    v_prime = torch.zeros(m)
    v_prime[second_support] = normalize(positive_vector[second_support], dim=0)

    leading_columns = torch.stack([v, v_prime]).T
    remaining_columns = _sample_semi_orthonormal_complement(leading_columns)
    full_basis = torch.hstack([leading_columns, remaining_columns])
    right_factor = _sample_orthonormal_matrix(n)
    spectrum = torch.diag(torch.abs(torch.randn([rank])))
    # Skip column 0 (= v) so that v is orthogonal to the column space of A, while keeping
    # column 1 (= v') as a left-singular vector with positive singular value.
    return full_basis[:, 1 : rank + 1] @ spectrum @ right_factor[:rank, :]
74-
75-
76-
def _sample_non_weak_matrix(m: int, n: int, rank: int) -> Tensor:
    """
    Samples a random non weakly-stationary matrix A of shape [m, n] with provided rank.

    This is done by sampling a positive u, and by then sampling a matrix A that has u as one of
    its left-singular vectors, with positive singular value s. Any 0 <= v, v != 0, satisfies
    v^T A != 0. Otherwise, we would have 0 = v^T A A^T u = s v^T u > 0, which is a
    contradiction. A is thus not weakly stationary.
    """

    assert 0 < rank <= min(m, n)

    positive_vector = torch.abs(torch.randn([m]))
    u_column = normalize(positive_vector, dim=0).unsqueeze(1)
    other_columns = _sample_semi_orthonormal_complement(u_column)
    # u is the first column of the basis, so it is always among the rank kept left-singular
    # vectors and gets a positive singular value.
    full_basis = torch.hstack([u_column, other_columns])
    right_factor = _sample_orthonormal_matrix(n)
    spectrum = torch.diag(torch.abs(torch.randn([rank])))
    return full_basis[:, :rank] @ spectrum @ right_factor[:rank, :]
96-
97-
98-
def _sample_orthonormal_matrix(dim: int) -> Tensor:
    """Uniformly samples a random orthonormal matrix of shape [dim, dim]."""

    # The semi-orthonormal complement of an empty basis is a full orthonormal basis of R^dim.
    empty_basis = torch.zeros([dim, 0])
    return _sample_semi_orthonormal_complement(empty_basis)
102-
103-
104-
def _sample_semi_orthonormal_complement(Q: Tensor) -> Tensor:
105-
"""
106-
Uniformly samples a random semi-orthonormal matrix Q' (i.e. Q'^T Q' = I) of shape [m, m-k]
107-
orthogonal to Q, i.e. such that the concatenation [Q, Q'] is an orthonormal matrix.
108-
109-
:param Q: A semi-orthonormal matrix (i.e. Q^T Q = I) of shape [m, k], with k <= m.
110-
"""
111-
112-
m, k = Q.shape
113-
A = torch.randn([m, m - k])
114-
115-
# project A onto the orthogonal complement of Q
116-
A_proj = A - Q @ (Q.T @ A)
117-
118-
Q_prime, _ = torch.linalg.qr(A_proj)
119-
return Q_prime
2+
from unit.conftest import DEVICE
1203

4+
from ._matrix_samplers import NonWeakSampler, NormalSampler, StrictlyWeakSampler, StrongSampler
1215

1226
_normal_dims = [
1237
(1, 1, 1),
@@ -144,14 +28,13 @@ def _sample_semi_orthonormal_complement(Q: Tensor) -> Tensor:
14428

14529
_scales = [0.0, 1e-10, 1e3, 1e5, 1e10, 1e15]
14630

147-
# Fix seed to fix randomness of matrix sampling
148-
torch.manual_seed(0)
31+
_rng = torch.Generator(device=DEVICE).manual_seed(0)
14932

150-
matrices = [_sample_matrix(m, n, r) for m, n, r in _normal_dims]
33+
matrices = [NormalSampler(m, n, r)(_rng) for m, n, r in _normal_dims]
15134
zero_matrices = [torch.zeros([m, n]) for m, n, _ in _zero_dims]
152-
strong_matrices = [_sample_strong_matrix(m, n, r) for m, n, r in _stationarity_dims]
153-
strictly_weak_matrices = [_sample_strictly_weak_matrix(m, n, r) for m, n, r in _stationarity_dims]
154-
non_weak_matrices = [_sample_non_weak_matrix(m, n, r) for m, n, r in _stationarity_dims]
35+
strong_matrices = [StrongSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims]
36+
strictly_weak_matrices = [StrictlyWeakSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims]
37+
non_weak_matrices = [NonWeakSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims]
15538

15639
scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices]
15740

@@ -170,4 +53,4 @@ def _sample_semi_orthonormal_complement(Q: Tensor) -> Tensor:
17053
(9, 11, 5),
17154
(9, 11, 9),
17255
]
173-
nash_mtl_matrices = [_sample_matrix(m, n, r) for m, n, r in _nashmtl_dims]
56+
nash_mtl_matrices = [NormalSampler(m, n, r)(_rng) for m, n, r in _nashmtl_dims]

0 commit comments

Comments
 (0)