Skip to content

Commit 05fda7a

Browse files
committed
Make cagrad and nashmtl raise an error if their solver is not installed.
1 parent 930da92 commit 05fda7a

2 files changed

Lines changed: 18 additions & 0 deletions

File tree

src/torchjd/aggregation/cagrad.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import importlib.util
2+
13
import cvxpy as cp
24
import numpy as np
35
import torch
@@ -6,6 +8,13 @@
68
from ._gramian_utils import compute_gramian, normalize
79
from .bases import _WeightedAggregator, _Weighting
810

11+
# Check that the ecos solver is installed
12+
if importlib.util.find_spec("ecos") is None:
13+
raise ModuleNotFoundError(
14+
"NashMTL requires the ecos solver, but it is not installed. Please run"
15+
"`pip install torchjd[nash_mtl]`."
16+
)
17+
918

1019
class CAGrad(_WeightedAggregator):
1120
"""

src/torchjd/aggregation/nash_mtl.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,21 @@
2424
# SOFTWARE.
2525

2626

27+
import importlib.util
28+
2729
import cvxpy as cp
2830
import numpy as np
2931
import torch
3032
from cvxpy import Expression
3133
from torch import Tensor
3234

35+
# Check that the ecos solver is installed
36+
if importlib.util.find_spec("ecos") is None:
37+
raise ModuleNotFoundError(
38+
"NashMTL requires the ecos solver, but it is not installed. Please run"
39+
"`pip install torchjd[nash_mtl]`."
40+
)
41+
3342
from .bases import _WeightedAggregator, _Weighting
3443

3544

0 commit comments

Comments
 (0)