Skip to content

Commit b07cc8f

Browse files
committed
Fix usage of OrderedSet in _create_transform of backward
1 parent 59dbdf3 commit b07cc8f

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tests/unit/autojac/test_backward.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from torchjd import backward
77
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad
8+
from torchjd.autojac._transform.ordered_set import OrderedSet
89
from torchjd.autojac.backward import _create_transform
910

1011

@@ -20,7 +21,7 @@ def test_check_create_transform():
2021
transform = _create_transform(
2122
tensors=[y1, y2],
2223
aggregator=Mean(),
23-
inputs={a1, a2},
24+
inputs=OrderedSet([a1, a2]),
2425
retain_graph=False,
2526
parallel_chunk_size=None,
2627
)

0 commit comments

Comments
 (0)