|
1 | 1 | import random as rand |
2 | 2 | from contextlib import nullcontext |
| 3 | +from importlib.util import find_spec |
3 | 4 |
|
4 | 5 | import torch |
5 | 6 | from pytest import RaisesExc, fixture, mark |
@@ -30,16 +31,35 @@ def pytest_addoption(parser): |
30 | 31 | def pytest_configure(config): |
31 | 32 | config.addinivalue_line("markers", "slow: mark test as slow to run") |
32 | 33 | config.addinivalue_line("markers", "xfail_if_cuda: mark test as xfail if running on cuda") |
| 34 | + config.addinivalue_line( |
| 35 | + "markers", "xfail_if_cagrad_not_installed: mark test as xfail if CAGrad is not installed" |
| 36 | + ) |
| 37 | + config.addinivalue_line( |
| 38 | + "markers", |
| 39 | + "xfail_if_nashmtl_not_installed: mark test as xfail if NashMTL is not installed", |
| 40 | + ) |
33 | 41 |
|
34 | 42 |
|
35 | 43 | def pytest_collection_modifyitems(config, items): |
36 | 44 | skip_slow = mark.skip(reason="Slow test. Use --runslow to run it.") |
37 | 45 | xfail_cuda = mark.xfail(reason=f"Test expected to fail on {DEVICE}") |
| 46 | + |
| 47 | + # Check if optional dependencies are installed |
| 48 | + cagrad_installed = all(find_spec(name) is not None for name in ["cvxpy", "clarabel"]) |
| 49 | + nashmtl_installed = all(find_spec(name) is not None for name in ["cvxpy", "ecos"]) |
| 50 | + |
| 51 | + xfail_cagrad = mark.xfail(reason="CAGrad dependencies not installed") |
| 52 | + xfail_nashmtl = mark.xfail(reason="NashMTL dependencies not installed") |
| 53 | + |
38 | 54 | for item in items: |
39 | 55 | if "slow" in item.keywords and not config.getoption("--runslow"): |
40 | 56 | item.add_marker(skip_slow) |
41 | 57 | if "xfail_if_cuda" in item.keywords and str(DEVICE).startswith("cuda"): |
42 | 58 | item.add_marker(xfail_cuda) |
| 59 | + if "xfail_if_cagrad_not_installed" in item.keywords and not cagrad_installed: |
| 60 | + item.add_marker(xfail_cagrad) |
| 61 | + if "xfail_if_nashmtl_not_installed" in item.keywords and not nashmtl_installed: |
| 62 | + item.add_marker(xfail_nashmtl) |
43 | 63 |
|
44 | 64 |
|
45 | 65 | def pytest_make_parametrize_id(config, val, argname): |
|
0 commit comments