Skip to content

Commit 3ff9a8b

Browse files
authored
test: Fix some test parametrizations (#576)
* Fix wrong type hint in test_weighting_output * Fix pytest_make_parametrize_id in cases where the string has a newline * Fix test_compute_partial_gramian parametrization to be deterministic
1 parent 6c71230 commit 3ff9a8b

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

tests/conftest.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ def pytest_make_parametrize_id(config, val, argname):
6464
elif isinstance(val, nullcontext):
6565
optional_string = "does_not_raise()"
6666

67-
if isinstance(optional_string, str) and len(optional_string) > MAX_SIZE:
68-
optional_string = optional_string[: MAX_SIZE - 3] + "+++" # Can't use dots with pytest
67+
if isinstance(optional_string, str):
68+
optional_string = optional_string.replace("\n", " ")
69+
if len(optional_string) > MAX_SIZE:
70+
optional_string = optional_string[: MAX_SIZE - 3] + "+++" # Can't use dots with pytest
6971

7072
return optional_string

tests/unit/aggregation/test_values.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
TrimmedMean,
3030
UPGrad,
3131
UPGradWeighting,
32+
Weighting,
3233
)
3334

3435
J_base = tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
@@ -116,7 +117,7 @@ def test_aggregator_output(A: Aggregator, J: Tensor, expected_output: Tensor):
116117

117118

118119
@mark.parametrize(["W", "G", "expected_output"], WEIGHTING_PARAMETRIZATIONS)
119-
def test_weighting_output(W: Aggregator, G: Tensor, expected_output: Tensor):
120+
def test_weighting_output(W: Weighting, G: Tensor, expected_output: Tensor):
120121
"""Test that the output values of a weighting are fixed (on cpu)."""
121122

122123
assert_close(W(G), expected_output, rtol=0, atol=1e-4)

tests/unit/autogram/test_engine.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,11 +301,13 @@ def test_compute_gramian_various_output_shapes(
301301
assert_close(autogram_gramian, expected_gramian, rtol=1e-4, atol=1e-5)
302302

303303

304-
def _non_empty_subsets(elements: set) -> list[set]:
304+
def _non_empty_subsets(S: set) -> list[list]:
305305
"""
306-
Generates the list of subsets of the given set, excluding the empty set.
306+
Generates the list of subsets of the given set, excluding the empty set. The sets are returned
307+
in the form of sorted lists so that the order is always the same, to make the parametrization of
308+
the test reproducible.
307309
"""
308-
return [set(c) for r in range(1, len(elements) + 1) for c in combinations(elements, r)]
310+
return [sorted(set(c)) for r in range(1, len(S) + 1) for c in combinations(S, r)]
309311

310312

311313
@mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"}))

0 commit comments

Comments
 (0)