Skip to content

Commit 20c2b41

Browse files
committed
Add tests for scaling invariance of _project_weights
1 parent 37d0813 commit 20c2b41

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

tests/unit/aggregation/test_dual_cone_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,23 @@ def test_solution_weights(shape: tuple[int, int]):
5151
assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)
5252

5353

54+
@mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)])
55+
@mark.parametrize("scaling", [0.25, 0.5, 4.0, 16.0])
56+
def test_scale_invariant(shape: tuple[int, int], scaling: float):
57+
"""
58+
Tests that `_project_weights` is invariant under scaling.
59+
"""
60+
61+
J = torch.randn(shape)
62+
G = J @ J.T
63+
u = torch.rand(shape[0])
64+
65+
w = _project_weights(u, G, "quadprog")
66+
w_scaled = _project_weights(u, scaling * G, "quadprog")
67+
68+
assert_close(w_scaled, w)
69+
70+
5471
@mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)])
5572
def test_tensorization_shape(shape: tuple[int, ...]):
5673
"""

0 commit comments

Comments
 (0)