Skip to content

Commit 2049a74

Browse files
committed
refactor(autojac): Reorder create_jac_dict parameters
1 parent abf7fbc commit 2049a74

5 files changed

Lines changed: 6 additions & 6 deletions

File tree

src/torchjd/autojac/_backward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def backward(
116116
else:
117117
inputs_ = OrderedSet(inputs)
118118

119-
jac_tensors_dict = create_jac_dict(tensors_, jac_tensors, "jac_tensors", "tensors")
119+
jac_tensors_dict = create_jac_dict(tensors_, jac_tensors, "tensors", "jac_tensors")
120120
transform = _create_transform(tensors_, inputs_, parallel_chunk_size, retain_graph)
121121
transform(jac_tensors_dict)
122122

src/torchjd/autojac/_jac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def jac(
155155
inputs_with_repetition = cast(Sequence[Tensor], (inputs,) if is_tensor_like(inputs) else inputs)
156156
inputs_ = OrderedSet(inputs_with_repetition)
157157

158-
jac_outputs_dict = create_jac_dict(outputs_, jac_outputs, "jac_outputs", "outputs")
158+
jac_outputs_dict = create_jac_dict(outputs_, jac_outputs, "outputs", "jac_outputs")
159159
transform = _create_transform(outputs_, inputs_, parallel_chunk_size, retain_graph)
160160
result = transform(jac_outputs_dict)
161161
return tuple(result[input] for input in inputs_with_repetition)

src/torchjd/autojac/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def check_matching_grad_shapes(
106106
def create_jac_dict(
107107
tensors: OrderedSet[Tensor],
108108
opt_jacobians: Sequence[Tensor] | Tensor | None,
109-
jacobian_param_name: str,
110109
tensor_param_name: str,
110+
jacobian_param_name: str,
111111
) -> dict[Tensor, Tensor]:
112112
"""
113113
Creates a dictionary mapping tensors to their corresponding Jacobians.
@@ -118,8 +118,8 @@ def create_jac_dict(
118118
:param tensors: The tensors to differentiate.
119119
:param opt_jacobians: The initial Jacobians to backpropagate. If ``None``, defaults to
120120
identity.
121-
:param jacobian_param_name: The name of the jacobian parameter for error messages.
122121
:param tensor_param_name: The name of the tensor parameter for error messages.
122+
:param jacobian_param_name: The name of the jacobian parameter for error messages.
123123
"""
124124
from torchjd.autojac._transform._diagonalize import Diagonalize
125125
from torchjd.autojac._transform._init import Init

tests/unit/autojac/test_backward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def test_check_create_transform(default_jac_tensors: bool) -> None:
2626
jac_tensors = create_jac_dict(
2727
tensors=OrderedSet([y1, y2]),
2828
opt_jacobians=optional_jac_tensors,
29-
jacobian_param_name="jac_tensors",
3029
tensor_param_name="tensors",
30+
jacobian_param_name="jac_tensors",
3131
)
3232
transform = _create_transform(
3333
tensors=OrderedSet([y1, y2]),

tests/unit/autojac/test_jac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def test_check_create_transform(default_jac_outputs: bool) -> None:
2626
jac_outputs = create_jac_dict(
2727
tensors=OrderedSet([y1, y2]),
2828
opt_jacobians=optional_jac_outputs,
29-
jacobian_param_name="jac_outputs",
3029
tensor_param_name="outputs",
30+
jacobian_param_name="jac_outputs",
3131
)
3232
transform = _create_transform(
3333
outputs=OrderedSet([y1, y2]),

0 commit comments

Comments
 (0)