Skip to content

Commit fff23ae

Browse files
committed
Add xfail_if_xyz_not_installed markers and use them instead of ignoring entire test files
1 parent 36fcb92 commit fff23ae

6 files changed

Lines changed: 46 additions & 25 deletions

File tree

.github/actions/run-tests/action.yml

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,6 @@ name: Run Tests
22
description: Run unit and doc tests with coverage
33

44
inputs:
5-
ignores:
6-
description: 'Space-separated list of test files to ignore'
7-
required: false
8-
default: ''
95
dtype:
106
description: 'Torch dtype to use for tests'
117
required: false
@@ -24,18 +20,4 @@ runs:
2420
PYTEST_TORCH_DTYPE: ${{ inputs.dtype }}
2521
PYTEST_TORCH_DEVICE: ${{ inputs.device }}
2622
run: |
27-
# Build the pytest command
28-
cmd="uv run pytest -W error tests/unit tests/doc"
29-
30-
# Add ignore flags for each file
31-
if [ -n "${{ inputs.ignores }}" ]; then
32-
for file in ${{ inputs.ignores }}; do
33-
cmd="$cmd --ignore $file"
34-
done
35-
fi
36-
37-
# Add coverage options
38-
cmd="$cmd --cov=src --cov-report=xml"
39-
40-
# Execute the command
41-
eval $cmd
23+
uv run pytest -W error tests/unit tests/doc --cov=src --cov-report=xml

.github/workflows/tests.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
token: ${{ secrets.CODECOV_TOKEN }}
4242

4343
tests-default-install:
44-
name: Run (most) tests with default install
44+
name: Run tests with default install
4545
runs-on: ubuntu-latest
4646
steps:
4747
- name: Checkout repository
@@ -54,8 +54,6 @@ jobs:
5454
groups: test
5555

5656
- uses: ./.github/actions/run-tests
57-
with:
58-
ignores: tests/unit/aggregation/test_cagrad.py tests/unit/aggregation/test_nash_mtl.py tests/doc/test_aggregation.py
5957

6058
- *upload-codecov
6159

tests/conftest.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import random as rand
22
from contextlib import nullcontext
3+
from importlib.util import find_spec
34

45
import torch
56
from pytest import RaisesExc, fixture, mark
@@ -30,16 +31,35 @@ def pytest_addoption(parser):
3031
def pytest_configure(config):
3132
config.addinivalue_line("markers", "slow: mark test as slow to run")
3233
config.addinivalue_line("markers", "xfail_if_cuda: mark test as xfail if running on cuda")
34+
config.addinivalue_line(
35+
"markers", "xfail_if_cagrad_not_installed: mark test as xfail if CAGrad is not installed"
36+
)
37+
config.addinivalue_line(
38+
"markers",
39+
"xfail_if_nashmtl_not_installed: mark test as xfail if NashMTL is not installed",
40+
)
3341

3442

3543
def pytest_collection_modifyitems(config, items):
3644
skip_slow = mark.skip(reason="Slow test. Use --runslow to run it.")
3745
xfail_cuda = mark.xfail(reason=f"Test expected to fail on {DEVICE}")
46+
47+
# Check if optional dependencies are installed
48+
cagrad_installed = all(find_spec(name) is not None for name in ["cvxpy", "clarabel"])
49+
nashmtl_installed = all(find_spec(name) is not None for name in ["cvxpy", "ecos"])
50+
51+
xfail_cagrad = mark.xfail(reason="CAGrad dependencies not installed")
52+
xfail_nashmtl = mark.xfail(reason="NashMTL dependencies not installed")
53+
3854
for item in items:
3955
if "slow" in item.keywords and not config.getoption("--runslow"):
4056
item.add_marker(skip_slow)
4157
if "xfail_if_cuda" in item.keywords and str(DEVICE).startswith("cuda"):
4258
item.add_marker(xfail_cuda)
59+
if "xfail_if_cagrad_not_installed" in item.keywords and not cagrad_installed:
60+
item.add_marker(xfail_cagrad)
61+
if "xfail_if_nashmtl_not_installed" in item.keywords and not nashmtl_installed:
62+
item.add_marker(xfail_nashmtl)
4363

4464

4565
def pytest_make_parametrize_id(config, val, argname):

tests/unit/aggregation/test_cagrad.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from ._asserts import assert_expected_structure, assert_non_conflicting, assert_non_differentiable
1111
from ._inputs import scaled_matrices, typical_matrices
1212

13+
pytestmark = mark.xfail_if_cagrad_not_installed
14+
1315
scaled_pairs = [(CAGrad(c=0.5), matrix) for matrix in scaled_matrices]
1416
typical_pairs = [(CAGrad(c=0.5), matrix) for matrix in typical_matrices]
1517
requires_grad_pairs = [(CAGrad(c=0.5), ones_(3, 5, requires_grad=True))]

tests/unit/aggregation/test_nash_mtl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from ._asserts import assert_expected_structure, assert_non_differentiable
99
from ._inputs import nash_mtl_matrices
1010

11+
pytestmark = mark.xfail_if_nashmtl_not_installed
12+
1113

1214
def _make_aggregator(matrix: Tensor) -> NashMTL:
1315
return NashMTL(n_tasks=matrix.shape[0])

tests/unit/aggregation/test_values.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,22 @@
8787
try:
8888
from torchjd.aggregation import CAGrad, CAGradWeighting
8989

90-
AGGREGATOR_PARAMETRIZATIONS.append((CAGrad(c=0.5), J_base, tensor([0.1835, 1.2041, 1.2041])))
91-
WEIGHTING_PARAMETRIZATIONS.append((CAGradWeighting(c=0.5), G_base, tensor([0.7041, 0.5000])))
90+
AGGREGATOR_PARAMETRIZATIONS.append(
91+
param(
92+
CAGrad(c=0.5),
93+
J_base,
94+
tensor([0.1835, 1.2041, 1.2041]),
95+
marks=mark.xfail_if_cagrad_not_installed,
96+
)
97+
)
98+
WEIGHTING_PARAMETRIZATIONS.append(
99+
param(
100+
CAGradWeighting(c=0.5),
101+
G_base,
102+
tensor([0.7041, 0.5000]),
103+
marks=mark.xfail_if_cagrad_not_installed,
104+
)
105+
)
92106
except ImportError:
93107
pass
94108

@@ -100,7 +114,10 @@
100114
NashMTL(n_tasks=2),
101115
J_base,
102116
tensor([0.0542, 0.7061, 0.7061]),
103-
marks=mark.filterwarnings("ignore::UserWarning"),
117+
marks=[
118+
mark.filterwarnings("ignore::UserWarning"),
119+
mark.xfail_if_nashmtl_not_installed,
120+
],
104121
)
105122
)
106123

0 commit comments

Comments
 (0)