Skip to content

Commit bede311

Browse files
committed
Extract the Aggregator-based logic of jac_to_grad
1 parent 07867c0 commit bede311

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,18 @@ def jac_to_grad(
7373
if not retain_jac:
7474
_free_jacs(tensors_)
7575

76+
gradients = _jacobian_based(aggregator, jacobians, tensors_)
77+
78+
accumulate_grads(tensors_, gradients)
79+
80+
81+
def _jacobian_based(
82+
aggregator: Aggregator, jacobians: list[Tensor], tensors: list[TensorWithJac]
83+
) -> list[Tensor]:
7684
jacobian_matrix = _unite_jacobians(jacobians)
7785
gradient_vector = aggregator(jacobian_matrix)
78-
gradients = _disunite_gradient(gradient_vector, jacobians, tensors_)
79-
accumulate_grads(tensors_, gradients)
86+
gradients = _disunite_gradient(gradient_vector, jacobians, tensors)
87+
return gradients
8088

8189

8290
def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:

0 commit comments

Comments
 (0)