We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
jac_to_grad
1 parent 07867c0 · commit bede311 — Copy full SHA for bede311
src/torchjd/autojac/_jac_to_grad.py
@@ -73,10 +73,18 @@ def jac_to_grad(
73
if not retain_jac:
74
_free_jacs(tensors_)
75
76
+ gradients = _jacobian_based(aggregator, jacobians, tensors_)
77
+
78
+ accumulate_grads(tensors_, gradients)
79
80
81
+def _jacobian_based(
82
+ aggregator: Aggregator, jacobians: list[Tensor], tensors: list[TensorWithJac]
83
+) -> list[Tensor]:
84
jacobian_matrix = _unite_jacobians(jacobians)
85
gradient_vector = aggregator(jacobian_matrix)
- gradients = _disunite_gradient(gradient_vector, jacobians, tensors_)
- accumulate_grads(tensors_, gradients)
86
+ gradients = _disunite_gradient(gradient_vector, jacobians, tensors)
87
+ return gradients
88
89
90
def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
0 commit comments