Skip to content

Commit e082c10

Browse files
authored
Merge branch 'main' into improve-aggregate-typing
2 parents 87f3900 + 0a61131 commit e082c10

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

tests/unit/autojac/test_backward.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from torchjd import backward
77
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad
8+
from torchjd.autojac._transform.ordered_set import OrderedSet
89
from torchjd.autojac.backward import _create_transform
910

1011

@@ -20,7 +21,7 @@ def test_check_create_transform():
2021
transform = _create_transform(
2122
tensors=[y1, y2],
2223
aggregator=Mean(),
23-
inputs={a1, a2},
24+
inputs=OrderedSet([a1, a2]),
2425
retain_graph=False,
2526
parallel_chunk_size=None,
2627
)

tests/unit/autojac/test_mtl_backward.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from torchjd import mtl_backward
77
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad
8+
from torchjd.autojac._transform.ordered_set import OrderedSet
89
from torchjd.autojac.mtl_backward import _create_transform
910

1011

@@ -24,8 +25,8 @@ def test_check_create_transform():
2425
losses=[y1, y2],
2526
features=[f1, f2],
2627
aggregator=Mean(),
27-
tasks_params=[[p1], [p2]],
28-
shared_params={p0},
28+
tasks_params=[OrderedSet([p1]), OrderedSet([p2])],
29+
shared_params=OrderedSet([p0]),
2930
retain_graph=False,
3031
parallel_chunk_size=None,
3132
)

0 commit comments

Comments
 (0)