Skip to content

Commit 309c8f1

Browse files
committed
Make cagrad and nashmtl raise an error if their solver is not installed.
1 parent c9877dd commit 309c8f1

2 files changed

Lines changed: 26 additions & 4 deletions

File tree

src/torchjd/aggregation/cagrad.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from importlib.util import find_spec
2+
13
import cvxpy as cp
24
import numpy as np
35
import torch
@@ -6,6 +8,13 @@
68
from ._gramian_utils import compute_gramian, normalize
79
from .bases import _WeightedAggregator, _Weighting
810

11+
# Check that the clarabel solver is installed
12+
if find_spec("clarabel") is None:
13+
raise ModuleNotFoundError(
14+
"CAGrad requires the clarabel solver, but it is not installed. Please run"
15+
"`pip install torchjd[cagrad]`."
16+
)
17+
918

1019
class CAGrad(_WeightedAggregator):
1120
"""
@@ -31,8 +40,10 @@ class CAGrad(_WeightedAggregator):
3140
tensor([0.1835, 1.2041, 1.2041])
3241
3342
.. note::
34-
This aggregator has dependencies that are not included by default when installing
35-
``torchjd``. To install them, use ``pip install torchjd[cagrad]``.
43+
This aggregator is not installed by default. When not installed, trying to import it should
44+
result in the following error:
45+
``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
46+
To install it, use ``pip install torchjd[cagrad]``.
3647
"""
3748

3849
def __init__(self, c: float, norm_eps: float = 0.0001):

src/torchjd/aggregation/nash_mtl.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,21 @@
2424
# SOFTWARE.
2525

2626

27+
from importlib.util import find_spec
28+
2729
import cvxpy as cp
2830
import numpy as np
2931
import torch
3032
from cvxpy import Expression
3133
from torch import Tensor
3234

35+
# Check that the ecos solver is installed
36+
if find_spec("ecos") is None:
37+
raise ModuleNotFoundError(
38+
"NashMTL requires the ecos solver, but it is not installed. Please run"
39+
"`pip install torchjd[nash_mtl]`."
40+
)
41+
3342
from .bases import _WeightedAggregator, _Weighting
3443

3544

@@ -61,8 +70,10 @@ class NashMTL(_WeightedAggregator):
6170
tensor([0.0542, 0.7061, 0.7061])
6271
6372
.. note::
64-
This aggregator has dependencies that are not included by default when installing
65-
``torchjd``. To install them, use ``pip install torchjd[nash_mtl]``.
73+
This aggregator is not installed by default. When not installed, trying to import it should
74+
result in the following error:
75+
``ImportError: cannot import name 'NashMTL' from 'torchjd.aggregation'``.
76+
To install it, use ``pip install torchjd[nash_mtl]``.
6677
6778
.. warning::
6879
This implementation was adapted from the `official implementation

0 commit comments

Comments
 (0)