Skip to content

Commit e09f4d8

Browse files
committed
Fix usage of OrderedSet in _create_transform of mtl_backward
1 parent b07cc8f commit e09f4d8

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tests/unit/autojac/test_mtl_backward.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from torchjd import mtl_backward
77
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad
8+
from torchjd.autojac._transform.ordered_set import OrderedSet
89
from torchjd.autojac.mtl_backward import _create_transform
910

1011

@@ -24,8 +25,8 @@ def test_check_create_transform():
2425
losses=[y1, y2],
2526
features=[f1, f2],
2627
aggregator=Mean(),
27-
tasks_params=[[p1], [p2]],
28-
shared_params={p0},
28+
tasks_params=[OrderedSet([p1]), OrderedSet([p2])],
29+
shared_params=OrderedSet([p0]),
2930
retain_graph=False,
3031
parallel_chunk_size=None,
3132
)

0 commit comments

Comments
 (0)