Skip to content

Commit c353713

Browse files
committed
Rename params to tensors in jac_to_grad
1 parent ce9231b commit c353713

1 file changed

Lines changed: 23 additions & 21 deletions

File tree

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 23 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -9,42 +9,44 @@
9 9
from ._accumulation import TensorWithJac, accumulate_grads
10 10

11 11

12-
def jac_to_grad(params: Iterable[Tensor], aggregator: Aggregator, retain_jac: bool = False) -> None:
12+
def jac_to_grad(
13+
tensors: Iterable[Tensor], aggregator: Aggregator, retain_jac: bool = False
14+
) -> None:
13 15
"""
14-
Aggregates the Jacobians stored in the ``.jac`` fields of ``params`` and accumulates the result
16+
Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result
15 17
into their ``.grad`` fields.
16 18
17-
:param params: The parameters whose ``.jac`` fields should be aggregated. All Jacobians must
19+
:param tensors: The tensors whose ``.jac`` fields should be aggregated. All Jacobians must
18 20
have the same first dimension (number of outputs).
19 21
:param aggregator: The aggregator used to reduce the Jacobians into gradients.
20-
:param retain_jac: Whether to preserve the ``.jac`` fields of the parameters.
22+
:param retain_jac: Whether to preserve the ``.jac`` fields of the tensors.
21 23
"""
22 24

23-
params_ = list[TensorWithJac]()
24-
for p in params:
25-
if not hasattr(p, "jac"):
25+
tensors_ = list[TensorWithJac]()
26+
for t in tensors:
27+
if not hasattr(t, "jac"):
26 28
raise ValueError(
27 29
"Some `jac` fields were not populated. Did you use `autojac.backward` before"
28 30
"calling `jac_to_grad`?"
29 31
)
30-
p_ = cast(TensorWithJac, p)
31-
params_.append(p_)
32+
t_ = cast(TensorWithJac, t)
33+
tensors_.append(t_)
32 34

33-
if len(params_) == 0:
35+
if len(tensors_) == 0:
34 36
return
35 37

36-
jacobians = [p.jac for p in params_]
38+
jacobians = [t.jac for t in tensors_]
37 39

38 40
if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]]):
39 41
raise ValueError("All Jacobians should have the same number of rows.")
40 42

41 43
jacobian_matrix = _unite_jacobians(jacobians)
42 44
gradient_vector = aggregator(jacobian_matrix)
43-
gradients = _disunite_gradient(gradient_vector, jacobians, params_)
44-
accumulate_grads(params_, gradients)
45+
gradients = _disunite_gradient(gradient_vector, jacobians, tensors_)
46+
accumulate_grads(tensors_, gradients)
45 47

46 48
if not retain_jac:
47-
_free_jacs(params_)
49+
_free_jacs(tensors_)
48 50

49 51

50 52
def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
@@ -54,7 +56,7 @@ def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
54 56

55 57

56 58
def _disunite_gradient(
57-
gradient_vector: Tensor, jacobians: list[Tensor], params: list[TensorWithJac]
59+
gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac]
58 60
) -> list[Tensor]:
59 61
gradient_vectors = []
60 62
start = 0
@@ -63,16 +65,16 @@ def _disunite_gradient(
63 65
current_gradient_vector = gradient_vector[start:end]
64 66
gradient_vectors.append(current_gradient_vector)
65 67
start = end
66-
gradients = [g.view(param.shape) for param, g in zip(params, gradient_vectors, strict=True)]
68+
gradients = [g.view(t.shape) for t, g in zip(tensors, gradient_vectors, strict=True)]
67 69
return gradients
68 70

69 71

70-
def _free_jacs(params: Iterable[TensorWithJac]) -> None:
72+
def _free_jacs(tensors: Iterable[TensorWithJac]) -> None:
71 73
"""
72-
Deletes the ``.jac`` field of the provided parameters.
74+
Deletes the ``.jac`` field of the provided tensors.
73 75
74-
:param params: The parameters whose ``.jac`` fields should be cleared.
76+
:param tensors: The tensors whose ``.jac`` fields should be cleared.
75 77
"""
76 78

77-
for p in params:
78-
del p.jac
79+
for t in tensors:
80+
del t.jac

0 commit comments

Comments (0)