From bac102ae6eb0a64052b59c71bab62a12226808dc Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 31 Mar 2025 17:12:51 +0200 Subject: [PATCH 1/2] Fix usage of OrderedSet in _create_transform of backward --- tests/unit/autojac/test_backward.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index d8e57667c..4a19168a6 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -5,6 +5,7 @@ from torchjd import backward from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad +from torchjd.autojac._transform.ordered_set import OrderedSet from torchjd.autojac.backward import _create_transform @@ -20,7 +21,7 @@ def test_check_create_transform(): transform = _create_transform( tensors=[y1, y2], aggregator=Mean(), - inputs={a1, a2}, + inputs=OrderedSet([a1, a2]), retain_graph=False, parallel_chunk_size=None, ) From b85271badc2e5571f94f89fe28d94e9d689f066e Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 31 Mar 2025 17:12:59 +0200 Subject: [PATCH 2/2] Fix usage of OrderedSet in _create_transform of mtl_backward --- tests/unit/autojac/test_mtl_backward.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index d267f587b..ba0050572 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -5,6 +5,7 @@ from torchjd import mtl_backward from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad +from torchjd.autojac._transform.ordered_set import OrderedSet from torchjd.autojac.mtl_backward import _create_transform @@ -24,8 +25,8 @@ def test_check_create_transform(): losses=[y1, y2], features=[f1, f2], aggregator=Mean(), - tasks_params=[[p1], [p2]], - shared_params={p0}, + tasks_params=[OrderedSet([p1]), OrderedSet([p2])], + shared_params=OrderedSet([p0]), retain_graph=False, parallel_chunk_size=None, )