Skip to content

Commit 18fb2f4

Browse files
authored
Merge branch 'main' into ci/add-dependabot
2 parents cf3f195 + 3fb2e03 commit 18fb2f4

5 files changed

Lines changed: 53 additions & 72 deletions

File tree

src/torchjd/autojac/_backward.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
from collections.abc import Iterable, Sequence
2-
from typing import cast
32

43
from torch import Tensor
5-
from torch.overrides import is_tensor_like
64

7-
from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform
5+
from ._transform import AccumulateJac, Jac, OrderedSet, Transform
86
from ._utils import (
97
as_checked_ordered_set,
10-
check_consistent_first_dimension,
11-
check_matching_jac_shapes,
12-
check_matching_length,
138
check_optional_positive_chunk_size,
9+
create_jac_dict,
1410
get_leaf_tensors,
1511
)
1612

@@ -120,37 +116,11 @@ def backward(
120116
else:
121117
inputs_ = OrderedSet(inputs)
122118

123-
jac_tensors_dict = _create_jac_tensors_dict(tensors_, jac_tensors)
119+
jac_tensors_dict = create_jac_dict(tensors_, jac_tensors, "tensors", "jac_tensors")
124120
transform = _create_transform(tensors_, inputs_, parallel_chunk_size, retain_graph)
125121
transform(jac_tensors_dict)
126122

127123

128-
def _create_jac_tensors_dict(
129-
tensors: OrderedSet[Tensor],
130-
opt_jac_tensors: Sequence[Tensor] | Tensor | None,
131-
) -> dict[Tensor, Tensor]:
132-
"""
133-
Creates a dictionary mapping tensors to their corresponding Jacobians.
134-
135-
:param tensors: The tensors to differentiate.
136-
:param opt_jac_tensors: The initial Jacobians to backpropagate. If ``None``, defaults to
137-
identity.
138-
"""
139-
if opt_jac_tensors is None:
140-
# Transform that creates gradient outputs containing only ones.
141-
init = Init(tensors)
142-
# Transform that turns the gradients into Jacobians.
143-
diag = Diagonalize(tensors)
144-
return (diag << init)({})
145-
jac_tensors = cast(
146-
Sequence[Tensor], (opt_jac_tensors,) if is_tensor_like(opt_jac_tensors) else opt_jac_tensors
147-
)
148-
check_matching_length(jac_tensors, tensors, "jac_tensors", "tensors")
149-
check_matching_jac_shapes(jac_tensors, tensors, "jac_tensors", "tensors")
150-
check_consistent_first_dimension(jac_tensors, "jac_tensors")
151-
return dict(zip(tensors, jac_tensors, strict=True))
152-
153-
154124
def _create_transform(
155125
tensors: OrderedSet[Tensor],
156126
inputs: OrderedSet[Tensor],

src/torchjd/autojac/_jac.py

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,12 @@
55
from torch.overrides import is_tensor_like
66

77
from torchjd.autojac._transform._base import Transform
8-
from torchjd.autojac._transform._diagonalize import Diagonalize
9-
from torchjd.autojac._transform._init import Init
108
from torchjd.autojac._transform._jac import Jac
119
from torchjd.autojac._transform._ordered_set import OrderedSet
1210
from torchjd.autojac._utils import (
1311
as_checked_ordered_set,
14-
check_consistent_first_dimension,
15-
check_matching_jac_shapes,
16-
check_matching_length,
1712
check_optional_positive_chunk_size,
13+
create_jac_dict,
1814
)
1915

2016

@@ -159,38 +155,12 @@ def jac(
159155
inputs_with_repetition = cast(Sequence[Tensor], (inputs,) if is_tensor_like(inputs) else inputs)
160156
inputs_ = OrderedSet(inputs_with_repetition)
161157

162-
jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs)
158+
jac_outputs_dict = create_jac_dict(outputs_, jac_outputs, "outputs", "jac_outputs")
163159
transform = _create_transform(outputs_, inputs_, parallel_chunk_size, retain_graph)
164160
result = transform(jac_outputs_dict)
165161
return tuple(result[input] for input in inputs_with_repetition)
166162

167163

168-
def _create_jac_outputs_dict(
169-
outputs: OrderedSet[Tensor],
170-
opt_jac_outputs: Sequence[Tensor] | Tensor | None,
171-
) -> dict[Tensor, Tensor]:
172-
"""
173-
Creates a dictionary mapping outputs to their corresponding Jacobians.
174-
175-
:param outputs: The tensors to differentiate.
176-
:param opt_jac_outputs: The initial Jacobians to backpropagate. If ``None``, defaults to
177-
identity.
178-
"""
179-
if opt_jac_outputs is None:
180-
# Transform that creates gradient outputs containing only ones.
181-
init = Init(outputs)
182-
# Transform that turns the gradients into Jacobians.
183-
diag = Diagonalize(outputs)
184-
return (diag << init)({})
185-
jac_outputs = cast(
186-
Sequence[Tensor], (opt_jac_outputs,) if is_tensor_like(opt_jac_outputs) else opt_jac_outputs
187-
)
188-
check_matching_length(jac_outputs, outputs, "jac_outputs", "outputs")
189-
check_matching_jac_shapes(jac_outputs, outputs, "jac_outputs", "outputs")
190-
check_consistent_first_dimension(jac_outputs, "jac_outputs")
191-
return dict(zip(outputs, jac_outputs, strict=True))
192-
193-
194164
def _create_transform(
195165
outputs: OrderedSet[Tensor],
196166
inputs: OrderedSet[Tensor],

src/torchjd/autojac/_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,41 @@ def check_matching_grad_shapes(
103103
)
104104

105105

106+
def create_jac_dict(
107+
tensors: OrderedSet[Tensor],
108+
opt_jacobians: Sequence[Tensor] | Tensor | None,
109+
tensor_param_name: str,
110+
jacobian_param_name: str,
111+
) -> dict[Tensor, Tensor]:
112+
"""
113+
Creates a dictionary mapping tensors to their corresponding Jacobians.
114+
115+
If ``opt_jacobians`` is ``None``, creates identity Jacobians using Init and Diagonalize
116+
transforms. Otherwise, validates the provided Jacobians and returns them as a dict.
117+
118+
:param tensors: The tensors to differentiate.
119+
:param opt_jacobians: The initial Jacobians to backpropagate. If ``None``, defaults to
120+
identity.
121+
:param tensor_param_name: The name of the tensor parameter for error messages.
122+
:param jacobian_param_name: The name of the jacobian parameter for error messages.
123+
"""
124+
from torchjd.autojac._transform._diagonalize import Diagonalize
125+
from torchjd.autojac._transform._init import Init
126+
127+
if opt_jacobians is None:
128+
init = Init(tensors)
129+
diag = Diagonalize(tensors)
130+
return (diag << init)({})
131+
132+
jacobians = cast(
133+
Sequence[Tensor], (opt_jacobians,) if is_tensor_like(opt_jacobians) else opt_jacobians
134+
)
135+
check_matching_length(jacobians, tensors, jacobian_param_name, tensor_param_name)
136+
check_matching_jac_shapes(jacobians, tensors, jacobian_param_name, tensor_param_name)
137+
check_consistent_first_dimension(jacobians, jacobian_param_name)
138+
return dict(zip(tensors, jacobians, strict=True))
139+
140+
106141
def check_consistent_first_dimension(
107142
jacobians: Sequence[Tensor],
108143
variable_name: str,

tests/unit/autojac/test_backward.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from utils.tensors import eye_, randn_, tensor_
55

66
from torchjd.autojac import backward
7-
from torchjd.autojac._backward import _create_jac_tensors_dict, _create_transform
7+
from torchjd.autojac._backward import _create_transform
88
from torchjd.autojac._transform import OrderedSet
9+
from torchjd.autojac._utils import create_jac_dict
910

1011

1112
@mark.parametrize("default_jac_tensors", [True, False])
@@ -22,9 +23,11 @@ def test_check_create_transform(default_jac_tensors: bool) -> None:
2223
None if default_jac_tensors else [tensor_([1.0, 0.0]), tensor_([0.0, 1.0])]
2324
)
2425

25-
jac_tensors = _create_jac_tensors_dict(
26+
jac_tensors = create_jac_dict(
2627
tensors=OrderedSet([y1, y2]),
27-
opt_jac_tensors=optional_jac_tensors,
28+
opt_jacobians=optional_jac_tensors,
29+
tensor_param_name="tensors",
30+
jacobian_param_name="jac_tensors",
2831
)
2932
transform = _create_transform(
3033
tensors=OrderedSet([y1, y2]),

tests/unit/autojac/test_jac.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from utils.tensors import eye_, randn_, tensor_
55

66
from torchjd.autojac import jac
7-
from torchjd.autojac._jac import _create_jac_outputs_dict, _create_transform
7+
from torchjd.autojac._jac import _create_transform
88
from torchjd.autojac._transform import OrderedSet
9+
from torchjd.autojac._utils import create_jac_dict
910

1011

1112
@mark.parametrize("default_jac_outputs", [True, False])
@@ -22,9 +23,11 @@ def test_check_create_transform(default_jac_outputs: bool) -> None:
2223
None if default_jac_outputs else [tensor_([1.0, 0.0]), tensor_([0.0, 1.0])]
2324
)
2425

25-
jac_outputs = _create_jac_outputs_dict(
26-
outputs=OrderedSet([y1, y2]),
27-
opt_jac_outputs=optional_jac_outputs,
26+
jac_outputs = create_jac_dict(
27+
tensors=OrderedSet([y1, y2]),
28+
opt_jacobians=optional_jac_outputs,
29+
tensor_param_name="outputs",
30+
jacobian_param_name="jac_outputs",
2831
)
2932
transform = _create_transform(
3033
outputs=OrderedSet([y1, y2]),

0 commit comments

Comments
 (0)