Skip to content

Commit 0f815b5

Browse files
authored
fix(aggregation): Fix optional dependencies (#329)
* Change tests.yml to add a tests-default-install job * Rename jobs in tests.yml * Add try/except around nash_mtl and cagrad import in aggregation/__init__.py * Add _check_dependencies.py with _OptionalDepsNotInstalledError and check_dependencies_are_installed * Add call to check_dependencies_are_installed in cagrad.py and nash_mtl.py * Ignore E402 error in flake8 config
1 parent a0f9173 commit 0f815b5

File tree

6 files changed

+68
-8
lines changed

6 files changed

+68
-8
lines changed

.github/workflows/tests.yml

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ on:
77
- cron: '41 16 * * *' # Every day at 16:41 UTC (to avoid high load at exact hour values).
88

99
jobs:
10-
Testing:
10+
tests-full-install:
11+
name: Run tests with full install
1112
runs-on: ${{ matrix.os }}
1213
strategy:
1314
fail-fast: false # Ensure matrix jobs keep running even if one fails
@@ -30,6 +31,29 @@ jobs:
3031
with:
3132
token: ${{ secrets.CODECOV_TOKEN }}
3233

34+
tests-default-install:
35+
name: Run (most) tests with default install
36+
runs-on: ubuntu-latest
37+
steps:
38+
- uses: actions/checkout@v4
39+
- name: Set up PDM
40+
uses: pdm-project/setup-pdm@v4
41+
with:
42+
python-version: '3.13'
43+
- name: Install default (without any option) and test dependencies
44+
run: pdm install --group test --frozen-lockfile
45+
- name: Run unit and doc tests with coverage report
46+
run: |
47+
pdm run pytest tests/unit tests/doc \
48+
--ignore tests/unit/aggregation/test_cagrad.py \
49+
--ignore tests/unit/aggregation/test_nash_mtl.py \
50+
--ignore tests/doc/test_aggregation.py \
51+
--cov=src --cov-report=xml
52+
- name: Upload results to Codecov
53+
uses: codecov/codecov-action@v4
54+
with:
55+
token: ${{ secrets.CODECOV_TOKEN }}
56+
3357
build-doc:
3458
name: Build doc
3559
runs-on: ubuntu-latest

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ repos:
1414
hooks:
1515
- id: flake8 # Check style and syntax. Does not modify code, issues have to be solved manually.
1616
args: [
17-
'--ignore=E501,E203,W503', # Ignore line length problems, space after colon problems, line break occurring before a binary operator problems.
17+
'--ignore=E501,E203,W503,E402', # Ignore line length problems, space after colon problems, line break occurring before a binary operator problems, module level import not at top of file problems.
1818
'--per-file-ignores=*/__init__.py:F401', # Ignore module imported but unused problems in __init__.py files.
1919
]
2020

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
from ._check_dependencies import _OptionalDepsNotInstalledError
12
from .aligned_mtl import AlignedMTL
23
from .bases import Aggregator
3-
from .cagrad import CAGrad
44
from .config import ConFIG
55
from .constant import Constant
66
from .dualproj import DualProj
@@ -9,9 +9,18 @@
99
from .krum import Krum
1010
from .mean import Mean
1111
from .mgda import MGDA
12-
from .nash_mtl import NashMTL
1312
from .pcgrad import PCGrad
1413
from .random import Random
1514
from .sum import Sum
1615
from .trimmed_mean import TrimmedMean
1716
from .upgrad import UPGrad
17+
18+
try:
19+
from .cagrad import CAGrad
20+
except _OptionalDepsNotInstalledError: # The required dependencies are not installed
21+
pass
22+
23+
try:
24+
from .nash_mtl import NashMTL
25+
except _OptionalDepsNotInstalledError: # The required dependencies are not installed
26+
pass
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from importlib.util import find_spec
2+
3+
4+
class _OptionalDepsNotInstalledError(ModuleNotFoundError):
5+
pass
6+
7+
8+
def check_dependencies_are_installed(dependency_names: list[str]) -> None:
9+
"""
10+
Check that the required list of dependencies are installed.
11+
12+
This can be useful for Aggregators whose dependencies are optional when installing torchjd.
13+
"""
14+
15+
if any(find_spec(name) is None for name in dependency_names):
16+
raise _OptionalDepsNotInstalledError()

src/torchjd/aggregation/cagrad.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from ._check_dependencies import check_dependencies_are_installed # noqa
2+
3+
check_dependencies_are_installed(["cvxpy", "clarabel"])
4+
15
import cvxpy as cp
26
import numpy as np
37
import torch
@@ -31,8 +35,10 @@ class CAGrad(_WeightedAggregator):
3135
tensor([0.1835, 1.2041, 1.2041])
3236
3337
.. note::
34-
This aggregator has dependencies that are not included by default when installing
35-
``torchjd``. To install them, use ``pip install torchjd[cagrad]``.
38+
This aggregator is not installed by default. When not installed, trying to import it should
39+
result in the following error:
40+
``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
41+
To install it, use ``pip install torchjd[cagrad]``.
3642
"""
3743

3844
def __init__(self, c: float, norm_eps: float = 0.0001):

src/torchjd/aggregation/nash_mtl.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2424
# SOFTWARE.
2525

26+
from ._check_dependencies import check_dependencies_are_installed
27+
28+
check_dependencies_are_installed(["cvxpy", "ecos"])
2629

2730
import cvxpy as cp
2831
import numpy as np
@@ -61,8 +64,10 @@ class NashMTL(_WeightedAggregator):
6164
tensor([0.0542, 0.7061, 0.7061])
6265
6366
.. note::
64-
This aggregator has dependencies that are not included by default when installing
65-
``torchjd``. To install them, use ``pip install torchjd[nash_mtl]``.
67+
This aggregator is not installed by default. When not installed, trying to import it should
68+
result in the following error:
69+
``ImportError: cannot import name 'NashMTL' from 'torchjd.aggregation'``.
70+
To install it, use ``pip install torchjd[nash_mtl]``.
6671
6772
.. warning::
6873
This implementation was adapted from the `official implementation

0 commit comments

Comments
 (0)