Skip to content

Commit a7803f9

Browse files
committed
Move find_spec to top of file
1 parent 3b3f795 commit a7803f9

2 files changed

Lines changed: 13 additions & 13 deletions

File tree

src/torchjd/aggregation/cagrad.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
from importlib.util import find_spec
22

3+
# Check that the clarabel solver is installed
4+
if find_spec("clarabel") is None:
5+
raise ModuleNotFoundError(
6+
"CAGrad requires the clarabel solver, but it is not installed. Please run"
7+
"`pip install torchjd[cagrad]`."
8+
)
9+
310
import cvxpy as cp
411
import numpy as np
512
import torch
@@ -8,13 +15,6 @@
815
from ._gramian_utils import compute_gramian, normalize
916
from .bases import _WeightedAggregator, _Weighting
1017

11-
# Check that the clarabel solver is installed
12-
if find_spec("clarabel") is None:
13-
raise ModuleNotFoundError(
14-
"CAGrad requires the clarabel solver, but it is not installed. Please run"
15-
"`pip install torchjd[cagrad]`."
16-
)
17-
1818

1919
class CAGrad(_WeightedAggregator):
2020
"""

src/torchjd/aggregation/nash_mtl.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,19 @@
2626

2727
from importlib.util import find_spec
2828

29-
import cvxpy as cp
30-
import numpy as np
31-
import torch
32-
from cvxpy import Expression
33-
from torch import Tensor
34-
3529
# Check that the ecos solver is installed
3630
if find_spec("ecos") is None:
3731
raise ModuleNotFoundError(
3832
"NashMTL requires the ecos solver, but it is not installed. Please run"
3933
"`pip install torchjd[nash_mtl]`."
4034
)
4135

36+
import cvxpy as cp
37+
import numpy as np
38+
import torch
39+
from cvxpy import Expression
40+
from torch import Tensor
41+
4242
from .bases import _WeightedAggregator, _Weighting
4343

4444

0 commit comments

Comments
 (0)