File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1+ from importlib .util import find_spec
2+
13import cvxpy as cp
24import numpy as np
35import torch
68from ._gramian_utils import compute_gramian , normalize
79from .bases import _WeightedAggregator , _Weighting
810
11+ # Check that the clarabel solver is installed
12+ if find_spec ("clarabel" ) is None :
13+ raise ModuleNotFoundError (
14+ "CAGrad requires the clarabel solver, but it is not installed. Please run"
15+ "`pip install torchjd[cagrad]`."
16+ )
17+
918
1019class CAGrad (_WeightedAggregator ):
1120 """
@@ -31,8 +40,10 @@ class CAGrad(_WeightedAggregator):
3140 tensor([0.1835, 1.2041, 1.2041])
3241
3342 .. note::
34- This aggregator has dependencies that are not included by default when installing
35- ``torchjd``. To install them, use ``pip install torchjd[cagrad]``.
43+ This aggregator is not installed by default. When not installed, trying to import it should
44+ result in the following error:
45+ ``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
46+ To install it, use ``pip install torchjd[cagrad]``.
3647 """
3748
3849 def __init__ (self , c : float , norm_eps : float = 0.0001 ):
Original file line number Diff line number Diff line change 2424# SOFTWARE.
2525
2626
27+ from importlib .util import find_spec
28+
2729import cvxpy as cp
2830import numpy as np
2931import torch
3032from cvxpy import Expression
3133from torch import Tensor
3234
35+ # Check that the ecos solver is installed
36+ if find_spec ("ecos" ) is None :
37+ raise ModuleNotFoundError (
38+ "NashMTL requires the ecos solver, but it is not installed. Please run"
39+ "`pip install torchjd[nash_mtl]`."
40+ )
41+
3342from .bases import _WeightedAggregator , _Weighting
3443
3544
@@ -61,8 +70,10 @@ class NashMTL(_WeightedAggregator):
6170 tensor([0.0542, 0.7061, 0.7061])
6271
6372 .. note::
64- This aggregator has dependencies that are not included by default when installing
65- ``torchjd``. To install them, use ``pip install torchjd[nash_mtl]``.
73+ This aggregator is not installed by default. When not installed, trying to import it should
74+ result in the following error:
75+ ``ImportError: cannot import name 'NashMTL' from 'torchjd.aggregation'``.
76+ To install it, use ``pip install torchjd[nash_mtl]``.
6677
6778 .. warning::
6879 This implementation was adapted from the `official implementation
You can’t perform that action at this time.
0 commit comments