diff --git a/src/torchjd/autojac/_transform/diagonalize.py b/src/torchjd/autojac/_transform/diagonalize.py index 903b5e591..a1283ba86 100644 --- a/src/torchjd/autojac/_transform/diagonalize.py +++ b/src/torchjd/autojac/_transform/diagonalize.py @@ -1,5 +1,3 @@ -from typing import Iterable - import torch from torch import Tensor @@ -9,29 +7,28 @@ class Diagonalize(Transform[Gradients, Jacobians]): - def __init__(self, considered: Iterable[Tensor]): - self.considered = OrderedSet(considered) + def __init__(self, key_order: OrderedSet[Tensor]): + self.key_order = key_order self.indices: list[tuple[int, int]] = [] begin = 0 - for tensor in self.considered: + for tensor in self.key_order: end = begin + tensor.numel() self.indices.append((begin, end)) begin = end def __call__(self, tensors: Gradients) -> Jacobians: - flattened_considered_values = [tensors[key].reshape([-1]) for key in self.considered] + flattened_considered_values = [tensors[key].reshape([-1]) for key in self.key_order] diagonal_matrix = torch.cat(flattened_considered_values).diag() diagonalized_tensors = { key: diagonal_matrix[:, begin:end].reshape((-1,) + key.shape) - for (begin, end), key in zip(self.indices, self.considered) + for (begin, end), key in zip(self.indices, self.key_order) } return Jacobians(diagonalized_tensors) def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: - considered = set(self.considered) - if not considered == input_keys: + if not set(self.key_order) == input_keys: raise RequirementError( - f"The input_keys must match the considered keys. Found input_keys {input_keys} and" - f"considered keys {considered}." + f"The input_keys must match the key_order. Found input_keys {input_keys} and" + f"key_order {self.key_order}." ) - return considered + return input_keys diff --git a/src/torchjd/autojac/backward.py b/src/torchjd/autojac/backward.py index 12aada35a..6478bd578 100644 --- a/src/torchjd/autojac/backward.py +++ b/src/torchjd/autojac/backward.py @@ -103,7 +103,7 @@ def _create_transform( init = Init(tensors) # Transform that turns the gradients into Jacobians. - diag = Diagonalize(tensors) + diag = Diagonalize(OrderedSet(tensors)) # Transform that computes the required Jacobians. jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph) diff --git a/tests/unit/autojac/_transform/test_diagonalize.py b/tests/unit/autojac/_transform/test_diagonalize.py index cf29ec05c..95986948a 100644 --- a/tests/unit/autojac/_transform/test_diagonalize.py +++ b/tests/unit/autojac/_transform/test_diagonalize.py @@ -2,6 +2,7 @@ from pytest import raises from torchjd.autojac._transform import Diagonalize, Gradients, RequirementError +from torchjd.autojac._transform.ordered_set import OrderedSet from ._dict_assertions import assert_tensor_dicts_are_close @@ -13,7 +14,7 @@ def test_single_input(): value = torch.ones_like(key) input = Gradients({key: value}) - diag = Diagonalize([key]) + diag = Diagonalize(OrderedSet([key])) output = diag(input) expected_output = {key: torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])} @@ -32,7 +33,7 @@ def test_multiple_inputs(): value3 = torch.ones_like(key3) input = Gradients({key1: value1, key2: value2, key3: value3}) - diag = Diagonalize([key1, key2, key3]) + diag = Diagonalize(OrderedSet([key1, key2, key3])) output = diag(input) expected_output = { @@ -88,8 +89,8 @@ def test_permute_order(): value2 = torch.ones_like(key2) input = Gradients({key1: value1, key2: value2}) - permuted_diag = Diagonalize([key2, key1]) - diag = Diagonalize([key1, key2]) + permuted_diag = Diagonalize(OrderedSet([key2, key1])) + diag = Diagonalize(OrderedSet([key1, key2])) permuted_output = permuted_diag(input) output = {key1: permuted_output[key2], key2: permuted_output[key1]} # un-permute @@ -106,7 +107,7 @@ def test_check_keys(): key1 = torch.tensor([1.0]) key2 = torch.tensor([1.0]) - diag = Diagonalize([key1]) + diag = Diagonalize(OrderedSet([key1])) output_keys = diag.check_keys({key1}) assert output_keys == {key1} diff --git a/tests/unit/autojac/_transform/test_interactions.py b/tests/unit/autojac/_transform/test_interactions.py index ccebcba3d..ef5626ae5 100644 --- a/tests/unit/autojac/_transform/test_interactions.py +++ b/tests/unit/autojac/_transform/test_interactions.py @@ -17,6 +17,7 @@ Stack, TensorDict, ) +from torchjd.autojac._transform.ordered_set import OrderedSet from ._dict_assertions import assert_tensor_dicts_are_close @@ -35,7 +36,7 @@ def test_jac_is_stack_of_grads(): input = Gradients({y1: torch.ones_like(y1), y2: torch.ones_like(y2)}) jac = Jac(outputs=[y1, y2], inputs=[a1, a2], chunk_size=None, retain_graph=True) - diag = Diagonalize([y1, y2]) + diag = Diagonalize(OrderedSet([y1, y2])) jac_diag = jac << diag grad1 = Grad(outputs=[y1], inputs=[a1, a2]) @@ -101,7 +102,7 @@ def test_multiple_differentiations(): def test_str(): """Tests that the __str__ method works correctly even for a complex transform.""" init = Init([]) - diag = Diagonalize([]) + diag = Diagonalize(OrderedSet([])) jac = Jac([], [], chunk_size=None) transform = jac << diag << init