1414 RequirementError ,
1515)
1616from torchjd .autojac ._transform .aggregate import _AggregateMatrices , _Matrixify , _Reshape
17+ from torchjd .autojac ._transform .ordered_set import OrderedSet
1718
1819from ._dict_assertions import assert_tensor_dicts_are_close
1920
@@ -54,7 +55,7 @@ def test_aggregate_matrices_output_structure(jacobian_matrices: JacobianMatrices
5455 output of the desired structure.
5556 """
5657
57- aggregate_matrices = _AggregateMatrices (Random (), key_order = _keys )
58+ aggregate_matrices = _AggregateMatrices (Random (), key_order = OrderedSet ( _keys ) )
5859 gradient_vectors = aggregate_matrices (jacobian_matrices )
5960
6061 assert set (jacobian_matrices .keys ()) == set (gradient_vectors .keys ())
@@ -66,7 +67,7 @@ def test_aggregate_matrices_output_structure(jacobian_matrices: JacobianMatrices
6667def test_aggregate_matrices_empty_dict ():
6768 """Tests that applying _AggregateMatrices to an empty input gives an empty output."""
6869
69- aggregate_matrices = _AggregateMatrices (Random (), key_order = [] )
70+ aggregate_matrices = _AggregateMatrices (Random (), key_order = OrderedSet ([]) )
7071 gradient_vectors = aggregate_matrices (JacobianMatrices ({}))
7172 assert len (gradient_vectors ) == 0
7273
@@ -158,7 +159,7 @@ def test_aggregate_matrices_check_keys():
158159 key1 = torch .tensor ([1.0 ])
159160 key2 = torch .tensor ([2.0 ])
160161 key3 = torch .tensor ([2.0 ])
161- aggregate = _AggregateMatrices (Random (), [key2 , key1 ])
162+ aggregate = _AggregateMatrices (Random (), OrderedSet ( [key2 , key1 ]) )
162163
163164 output_keys = aggregate .check_keys ({key1 , key2 })
164165 assert output_keys == {key1 , key2 }
0 commit comments