Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions src/torchjd/autojac/_transform/diagonalize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Iterable

import torch
from torch import Tensor

Expand All @@ -9,29 +7,28 @@


class Diagonalize(Transform[Gradients, Jacobians]):
def __init__(self, considered: Iterable[Tensor]):
self.considered = OrderedSet(considered)
def __init__(self, key_order: OrderedSet[Tensor]):
self.key_order = key_order
self.indices: list[tuple[int, int]] = []
begin = 0
for tensor in self.considered:
for tensor in self.key_order:
end = begin + tensor.numel()
self.indices.append((begin, end))
begin = end

def __call__(self, tensors: Gradients) -> Jacobians:
flattened_considered_values = [tensors[key].reshape([-1]) for key in self.considered]
flattened_considered_values = [tensors[key].reshape([-1]) for key in self.key_order]
diagonal_matrix = torch.cat(flattened_considered_values).diag()
diagonalized_tensors = {
key: diagonal_matrix[:, begin:end].reshape((-1,) + key.shape)
for (begin, end), key in zip(self.indices, self.considered)
for (begin, end), key in zip(self.indices, self.key_order)
}
return Jacobians(diagonalized_tensors)

def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
considered = set(self.considered)
if not considered == input_keys:
if not set(self.key_order) == input_keys:
raise RequirementError(
f"The input_keys must match the considered keys. Found input_keys {input_keys} and"
f"considered keys {considered}."
f"The input_keys must match the key_order. Found input_keys {input_keys} and"
f"key_order {self.key_order}."
)
return considered
return input_keys
2 changes: 1 addition & 1 deletion src/torchjd/autojac/backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _create_transform(
init = Init(tensors)

# Transform that turns the gradients into Jacobians.
diag = Diagonalize(tensors)
diag = Diagonalize(OrderedSet(tensors))

# Transform that computes the required Jacobians.
jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph)
Expand Down
11 changes: 6 additions & 5 deletions tests/unit/autojac/_transform/test_diagonalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pytest import raises

from torchjd.autojac._transform import Diagonalize, Gradients, RequirementError
from torchjd.autojac._transform.ordered_set import OrderedSet

from ._dict_assertions import assert_tensor_dicts_are_close

Expand All @@ -13,7 +14,7 @@ def test_single_input():
value = torch.ones_like(key)
input = Gradients({key: value})

diag = Diagonalize([key])
diag = Diagonalize(OrderedSet([key]))

output = diag(input)
expected_output = {key: torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])}
Expand All @@ -32,7 +33,7 @@ def test_multiple_inputs():
value3 = torch.ones_like(key3)
input = Gradients({key1: value1, key2: value2, key3: value3})

diag = Diagonalize([key1, key2, key3])
diag = Diagonalize(OrderedSet([key1, key2, key3]))

output = diag(input)
expected_output = {
Expand Down Expand Up @@ -88,8 +89,8 @@ def test_permute_order():
value2 = torch.ones_like(key2)
input = Gradients({key1: value1, key2: value2})

permuted_diag = Diagonalize([key2, key1])
diag = Diagonalize([key1, key2])
permuted_diag = Diagonalize(OrderedSet([key2, key1]))
diag = Diagonalize(OrderedSet([key1, key2]))

permuted_output = permuted_diag(input)
output = {key1: permuted_output[key2], key2: permuted_output[key1]} # un-permute
Expand All @@ -106,7 +107,7 @@ def test_check_keys():

key1 = torch.tensor([1.0])
key2 = torch.tensor([1.0])
diag = Diagonalize([key1])
diag = Diagonalize(OrderedSet([key1]))

output_keys = diag.check_keys({key1})
assert output_keys == {key1}
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/autojac/_transform/test_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Stack,
TensorDict,
)
from torchjd.autojac._transform.ordered_set import OrderedSet

from ._dict_assertions import assert_tensor_dicts_are_close

Expand All @@ -35,7 +36,7 @@ def test_jac_is_stack_of_grads():
input = Gradients({y1: torch.ones_like(y1), y2: torch.ones_like(y2)})

jac = Jac(outputs=[y1, y2], inputs=[a1, a2], chunk_size=None, retain_graph=True)
diag = Diagonalize([y1, y2])
diag = Diagonalize(OrderedSet([y1, y2]))
jac_diag = jac << diag

grad1 = Grad(outputs=[y1], inputs=[a1, a2])
Expand Down Expand Up @@ -101,7 +102,7 @@ def test_multiple_differentiations():
def test_str():
"""Tests that the __str__ method works correctly even for a complex transform."""
init = Init([])
diag = Diagonalize([])
diag = Diagonalize(OrderedSet([]))
jac = Jac([], [], chunk_size=None)
transform = jac << diag << init

Expand Down