Skip to content

Commit 2f85f63

Browse files
committed
Fix key_order name and type in Diagonalize
1 parent 6c5559f commit 2f85f63

File tree

4 files changed

+19
-20
lines changed

4 files changed

+19
-20
lines changed
Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Iterable
2-
31
import torch
42
from torch import Tensor
53

@@ -9,29 +7,28 @@
97

108

119
class Diagonalize(Transform[Gradients, Jacobians]):
12-
def __init__(self, considered: Iterable[Tensor]):
13-
self.considered = OrderedSet(considered)
10+
def __init__(self, key_order: OrderedSet[Tensor]):
11+
self.key_order = OrderedSet(key_order)
1412
self.indices: list[tuple[int, int]] = []
1513
begin = 0
16-
for tensor in self.considered:
14+
for tensor in self.key_order:
1715
end = begin + tensor.numel()
1816
self.indices.append((begin, end))
1917
begin = end
2018

2119
def __call__(self, tensors: Gradients) -> Jacobians:
22-
flattened_considered_values = [tensors[key].reshape([-1]) for key in self.considered]
20+
flattened_considered_values = [tensors[key].reshape([-1]) for key in self.key_order]
2321
diagonal_matrix = torch.cat(flattened_considered_values).diag()
2422
diagonalized_tensors = {
2523
key: diagonal_matrix[:, begin:end].reshape((-1,) + key.shape)
26-
for (begin, end), key in zip(self.indices, self.considered)
24+
for (begin, end), key in zip(self.indices, self.key_order)
2725
}
2826
return Jacobians(diagonalized_tensors)
2927

3028
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
31-
considered = set(self.considered)
32-
if not considered == input_keys:
29+
if not set(self.key_order) == input_keys:
3330
raise RequirementError(
34-
f"The input_keys must match the considered keys. Found input_keys {input_keys} and"
35-
f"considered keys {considered}."
31+
f"The input_keys must match the key_order. Found input_keys {input_keys} and"
32+
f"key_order {self.key_order}."
3633
)
37-
return considered
34+
return input_keys

src/torchjd/autojac/backward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _create_transform(
103103
init = Init(tensors)
104104

105105
# Transform that turns the gradients into Jacobians.
106-
diag = Diagonalize(tensors)
106+
diag = Diagonalize(OrderedSet(tensors))
107107

108108
# Transform that computes the required Jacobians.
109109
jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph)

tests/unit/autojac/_transform/test_diagonalize.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pytest import raises
33

44
from torchjd.autojac._transform import Diagonalize, Gradients, RequirementError
5+
from torchjd.autojac._transform.ordered_set import OrderedSet
56

67
from ._dict_assertions import assert_tensor_dicts_are_close
78

@@ -13,7 +14,7 @@ def test_single_input():
1314
value = torch.ones_like(key)
1415
input = Gradients({key: value})
1516

16-
diag = Diagonalize([key])
17+
diag = Diagonalize(OrderedSet([key]))
1718

1819
output = diag(input)
1920
expected_output = {key: torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])}
@@ -32,7 +33,7 @@ def test_multiple_inputs():
3233
value3 = torch.ones_like(key3)
3334
input = Gradients({key1: value1, key2: value2, key3: value3})
3435

35-
diag = Diagonalize([key1, key2, key3])
36+
diag = Diagonalize(OrderedSet([key1, key2, key3]))
3637

3738
output = diag(input)
3839
expected_output = {
@@ -88,8 +89,8 @@ def test_permute_order():
8889
value2 = torch.ones_like(key2)
8990
input = Gradients({key1: value1, key2: value2})
9091

91-
permuted_diag = Diagonalize([key2, key1])
92-
diag = Diagonalize([key1, key2])
92+
permuted_diag = Diagonalize(OrderedSet([key2, key1]))
93+
diag = Diagonalize(OrderedSet([key1, key2]))
9394

9495
permuted_output = permuted_diag(input)
9596
output = {key1: permuted_output[key2], key2: permuted_output[key1]} # un-permute
@@ -106,7 +107,7 @@ def test_check_keys():
106107

107108
key1 = torch.tensor([1.0])
108109
key2 = torch.tensor([1.0])
109-
diag = Diagonalize([key1])
110+
diag = Diagonalize(OrderedSet([key1]))
110111

111112
output_keys = diag.check_keys({key1})
112113
assert output_keys == {key1}

tests/unit/autojac/_transform/test_interactions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Stack,
1818
TensorDict,
1919
)
20+
from torchjd.autojac._transform.ordered_set import OrderedSet
2021

2122
from ._dict_assertions import assert_tensor_dicts_are_close
2223

@@ -35,7 +36,7 @@ def test_jac_is_stack_of_grads():
3536
input = Gradients({y1: torch.ones_like(y1), y2: torch.ones_like(y2)})
3637

3738
jac = Jac(outputs=[y1, y2], inputs=[a1, a2], chunk_size=None, retain_graph=True)
38-
diag = Diagonalize([y1, y2])
39+
diag = Diagonalize(OrderedSet([y1, y2]))
3940
jac_diag = jac << diag
4041

4142
grad1 = Grad(outputs=[y1], inputs=[a1, a2])
@@ -101,7 +102,7 @@ def test_multiple_differentiations():
101102
def test_str():
102103
"""Tests that the __str__ method works correctly even for a complex transform."""
103104
init = Init([])
104-
diag = Diagonalize([])
105+
diag = Diagonalize(OrderedSet([]))
105106
jac = Jac([], [], chunk_size=None)
106107
transform = jac << diag << init
107108

0 commit comments

Comments
 (0)