Skip to content

Commit 24e9438

Browse files
authored
chore: Make some aggregator dependencies optional (#327)
* Add optional dependency groups nash_mtl, cagrad and full * Remove cxvpy and ecos from the default dependencies * Change tests.yml to also install the full dependency group * Update documentation about how to install torchjd in README.md, installation.md, cagrad.py and nash_mtl.py * Add changelog entry
1 parent d5ed6f3 commit 24e9438

File tree

7 files changed

+46
-4
lines changed

7 files changed

+46
-4
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
with:
2323
python-version: ${{ matrix.python-version }}
2424
- name: Install default and test dependencies
25-
run: pdm install --group test --frozen-lockfile
25+
run: pdm install --group full --group test --frozen-lockfile
2626
- name: Run unit and doc tests with coverage report
2727
run: pdm run pytest tests/unit tests/doc --cov=src --cov-report=xml
2828
- name: Upload results to Codecov

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ changes that do not affect the user.
88

99
## [Unreleased]
1010

11+
### Changed
12+
13+
- **BREAKING**: Changed the dependencies of `CAGrad` and `NashMTL` to be optional when installing
14+
TorchJD. Users of these aggregators will have to use `pip install torchjd[cagrad]`, `pip install
15+
torchjd[nash_mtl]` or `pip install torchjd[full]` to install TorchJD alongside those dependencies.
16+
This should make TorchJD more lightweight.
17+
1118
## [0.6.0] - 2025-04-19
1219

1320
### Added

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ TorchJD can be installed directly with pip:
5252
pip install torchjd
5353
```
5454
<!-- end installation -->
55+
Some aggregators may have additional dependencies. Please refer to the
56+
[installation documentation](https://torchjd.org/stable/installation) for them.
5557

5658
## Usage
5759
The main way to use TorchJD is to replace the usual call to `loss.backward()` by a call to

docs/source/installation.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,17 @@
66
```
77

88
Note that `torchjd` requires python 3.10, 3.11, 3.12 or 3.13 and `torch>=2.0`.
9+
10+
Some aggregators (CAGrad and Nash-MTL) have additional dependencies that are not included by default
11+
when installing `torchjd`. To install them, you can use:
12+
```
13+
pip install torchjd[cagrad]
14+
```
15+
```
16+
pip install torchjd[nash_mtl]
17+
```
18+
19+
To install `torchjd` with all of its optional dependencies, you can also use:
20+
```
21+
pip install torchjd[full]
22+
```

pyproject.toml

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ dependencies = [
1717
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
1818
"numpy>=1.21.0", # Does not work before 1.21
1919
"qpsolvers>=1.0.1", # Does not work before 1.0.1
20-
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
21-
"ecos>=2.0.14", # Does not work before 2.0.14
2220
]
2321
classifiers = [
2422
"Development Status :: 4 - Beta",
@@ -67,3 +65,16 @@ plot = [
6765
"dash>=2.16.0", # Recent version to avoid problems, could be relaxed
6866
"kaleido==0.2.1", # Only works with locked version
6967
]
68+
69+
[project.optional-dependencies]
70+
nash_mtl = [
71+
"cvxpy>=1.3.0", # Could be relaxed
72+
"ecos>=2.0.14", # Does not work before 2.0.14
73+
]
74+
cagrad = [
75+
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
76+
]
77+
full = [
78+
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
79+
"ecos>=2.0.14", # Does not work before 2.0.14
80+
]

src/torchjd/aggregation/cagrad.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ class CAGrad(_WeightedAggregator):
2929
>>>
3030
>>> A(J)
3131
tensor([0.1835, 1.2041, 1.2041])
32+
33+
.. note::
34+
This aggregator has dependencies that are not included by default when installing
35+
``torchjd``. To install them, use ``pip install torchjd[cagrad]``.
3236
"""
3337

3438
def __init__(self, c: float, norm_eps: float = 0.0001):

src/torchjd/aggregation/nash_mtl.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,16 @@ class NashMTL(_WeightedAggregator):
6060
>>> A(J)
6161
tensor([0.0542, 0.7061, 0.7061])
6262
63+
.. note::
64+
This aggregator has dependencies that are not included by default when installing
65+
``torchjd``. To install them, use ``pip install torchjd[nash_mtl]``.
66+
6367
.. warning::
6468
This implementation was adapted from the `official implementation
6569
<https://github.com/AvivNavon/nash-mtl/tree/main>`_, which has some flaws. Use with caution.
6670
6771
.. warning::
68-
The aggregator is stateful. Its output will thus depend not only on the input matrix, but
72+
This aggregator is stateful. Its output will thus depend not only on the input matrix, but
6973
also on its state. It thus depends on previously seen matrices. It should be reset between
7074
experiments.
7175
"""

0 commit comments

Comments
 (0)