We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent ff01c83 commit c9877ddCopy full SHA for c9877dd
1 file changed
src/torchjd/aggregation/__init__.py
@@ -1,6 +1,5 @@
1
from .aligned_mtl import AlignedMTL
2
from .bases import Aggregator
3
-from .cagrad import CAGrad
4
from .config import ConFIG
5
from .constant import Constant
6
from .dualproj import DualProj
@@ -9,9 +8,18 @@
9
8
from .krum import Krum
10
from .mean import Mean
11
from .mgda import MGDA
12
-from .nash_mtl import NashMTL
13
from .pcgrad import PCGrad
14
from .random import Random
15
from .sum import Sum
16
from .trimmed_mean import TrimmedMean
17
from .upgrad import UPGrad
+
+try:
18
+ from .cagrad import CAGrad
19
+except ImportError: # The required dependencies are not installed
20
+ pass
21
22
23
+ from .nash_mtl import NashMTL
24
25
0 commit comments