Skip to content

Commit a1658e2

Browse files
refactor(autojac): Make Aggregate typing more strict (#289)
* Change type of key_order parameter of Aggregate and _AggregateMatrices from Iterable[Tensor] to OrderedSet[Tensor] and update tests accordingly
1 parent 0a61131 commit a1658e2

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

src/torchjd/autojac/_transform/aggregate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import OrderedDict
2-
from typing import Hashable, Iterable, TypeVar
2+
from typing import Hashable, TypeVar
33

44
import torch
55
from torch import Tensor
@@ -15,7 +15,7 @@
1515

1616

1717
class Aggregate(Transform[Jacobians, Gradients]):
18-
def __init__(self, aggregator: Aggregator, key_order: Iterable[Tensor]):
18+
def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]):
1919
matrixify = _Matrixify()
2020
aggregate_matrices = _AggregateMatrices(aggregator, key_order)
2121
reshape = _Reshape()
@@ -31,7 +31,7 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
3131

3232

3333
class _AggregateMatrices(Transform[JacobianMatrices, GradientVectors]):
34-
def __init__(self, aggregator: Aggregator, key_order: Iterable[Tensor]):
34+
def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]):
3535
self.key_order = OrderedSet(key_order)
3636
self.aggregator = aggregator
3737

tests/unit/autojac/_transform/test_aggregate.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
RequirementError,
1515
)
1616
from torchjd.autojac._transform.aggregate import _AggregateMatrices, _Matrixify, _Reshape
17+
from torchjd.autojac._transform.ordered_set import OrderedSet
1718

1819
from ._dict_assertions import assert_tensor_dicts_are_close
1920

@@ -54,7 +55,7 @@ def test_aggregate_matrices_output_structure(jacobian_matrices: JacobianMatrices
5455
output of the desired structure.
5556
"""
5657

57-
aggregate_matrices = _AggregateMatrices(Random(), key_order=_keys)
58+
aggregate_matrices = _AggregateMatrices(Random(), key_order=OrderedSet(_keys))
5859
gradient_vectors = aggregate_matrices(jacobian_matrices)
5960

6061
assert set(jacobian_matrices.keys()) == set(gradient_vectors.keys())
@@ -66,7 +67,7 @@ def test_aggregate_matrices_output_structure(jacobian_matrices: JacobianMatrices
6667
def test_aggregate_matrices_empty_dict():
6768
"""Tests that applying _AggregateMatrices to an empty input gives an empty output."""
6869

69-
aggregate_matrices = _AggregateMatrices(Random(), key_order=[])
70+
aggregate_matrices = _AggregateMatrices(Random(), key_order=OrderedSet([]))
7071
gradient_vectors = aggregate_matrices(JacobianMatrices({}))
7172
assert len(gradient_vectors) == 0
7273

@@ -158,7 +159,7 @@ def test_aggregate_matrices_check_keys():
158159
key1 = torch.tensor([1.0])
159160
key2 = torch.tensor([2.0])
160161
key3 = torch.tensor([2.0])
161-
aggregate = _AggregateMatrices(Random(), [key2, key1])
162+
aggregate = _AggregateMatrices(Random(), OrderedSet([key2, key1]))
162163

163164
output_keys = aggregate.check_keys({key1, key2})
164165
assert output_keys == {key1, key2}

0 commit comments

Comments
 (0)