|
5 | 5 | from torch.overrides import is_tensor_like |
6 | 6 |
|
7 | 7 | from torchjd.autojac._transform._base import Transform |
8 | | -from torchjd.autojac._transform._diagonalize import Diagonalize |
9 | | -from torchjd.autojac._transform._init import Init |
10 | 8 | from torchjd.autojac._transform._jac import Jac |
11 | 9 | from torchjd.autojac._transform._ordered_set import OrderedSet |
12 | 10 | from torchjd.autojac._utils import ( |
13 | 11 | as_checked_ordered_set, |
14 | | - check_consistent_first_dimension, |
15 | | - check_matching_jac_shapes, |
16 | | - check_matching_length, |
17 | 12 | check_optional_positive_chunk_size, |
| 13 | + create_jac_dict, |
18 | 14 | ) |
19 | 15 |
|
20 | 16 |
|
@@ -159,38 +155,12 @@ def jac( |
159 | 155 | inputs_with_repetition = cast(Sequence[Tensor], (inputs,) if is_tensor_like(inputs) else inputs) |
160 | 156 | inputs_ = OrderedSet(inputs_with_repetition) |
161 | 157 |
|
162 | | - jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs) |
| 158 | + jac_outputs_dict = create_jac_dict(outputs_, jac_outputs, "outputs", "jac_outputs") |
163 | 159 | transform = _create_transform(outputs_, inputs_, parallel_chunk_size, retain_graph) |
164 | 160 | result = transform(jac_outputs_dict) |
165 | 161 | return tuple(result[input] for input in inputs_with_repetition) |
166 | 162 |
|
167 | 163 |
|
168 | | -def _create_jac_outputs_dict( |
169 | | - outputs: OrderedSet[Tensor], |
170 | | - opt_jac_outputs: Sequence[Tensor] | Tensor | None, |
171 | | -) -> dict[Tensor, Tensor]: |
172 | | - """ |
173 | | - Creates a dictionary mapping outputs to their corresponding Jacobians. |
174 | | -
|
175 | | - :param outputs: The tensors to differentiate. |
176 | | - :param opt_jac_outputs: The initial Jacobians to backpropagate. If ``None``, defaults to |
177 | | - identity. |
178 | | - """ |
179 | | - if opt_jac_outputs is None: |
180 | | - # Transform that creates gradient outputs containing only ones. |
181 | | - init = Init(outputs) |
182 | | - # Transform that turns the gradients into Jacobians. |
183 | | - diag = Diagonalize(outputs) |
184 | | - return (diag << init)({}) |
185 | | - jac_outputs = cast( |
186 | | - Sequence[Tensor], (opt_jac_outputs,) if is_tensor_like(opt_jac_outputs) else opt_jac_outputs |
187 | | - ) |
188 | | - check_matching_length(jac_outputs, outputs, "jac_outputs", "outputs") |
189 | | - check_matching_jac_shapes(jac_outputs, outputs, "jac_outputs", "outputs") |
190 | | - check_consistent_first_dimension(jac_outputs, "jac_outputs") |
191 | | - return dict(zip(outputs, jac_outputs, strict=True)) |
192 | | - |
193 | | - |
194 | 164 | def _create_transform( |
195 | 165 | outputs: OrderedSet[Tensor], |
196 | 166 | inputs: OrderedSet[Tensor], |
|
0 commit comments